Gaudi: clean cuda/rocm code in hpu backend, enable flat_hpu (#3113)

* clean cuda/rocm code in hpu backend, enable flat_hpu

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix TP in pageattn

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* adjust block table in hpu to improve performance

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable all the models, not tested yet

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* use tensor cache in hpu graph to avoid replay issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add moe support, fix qwen/mistral/mixtral crash

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix phimoe issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* gpt_bigcode can also use pageattn

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable dbrx, remove some unused code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* multi-modality initial PR

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* adjust warmup and enable vlm

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix incorrect output in qwen2/idefics if hpu graph is used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove unused quantization code and enable awq/gptq int4

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix gptq issue

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable fp8

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* warmup prefill

remove models where pageattn is not used; set block table to None since it's not used

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add warmup_decode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* warmup decode

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove block_tables and prefill_cache_indices, which would lead to dynamic shapes

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* missing gptj change...

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix some issues

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* remove torch.where to fix incorrect output in hpu graph model

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* match the latest vllm_extension ops

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
commit d62c941c56
parent 9a8d0462e1
Author: Wang, Yi <yi.a.wang@intel.com>
Date:   2025-04-14 21:58:13 +08:00 (committed by GitHub)
91 changed files with 8804 additions and 11813 deletions


@@ -95,7 +95,7 @@ RUN cd server && \
     pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
     BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
     pip install . --no-cache-dir
+RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router


@@ -16,15 +16,9 @@ app = typer.Typer()
 class Quantization(str, Enum):
-    bitsandbytes = "bitsandbytes"
-    bitsandbytes_nf4 = "bitsandbytes-nf4"
-    bitsandbytes_fp4 = "bitsandbytes-fp4"
     gptq = "gptq"
     awq = "awq"
-    eetq = "eetq"
-    exl2 = "exl2"
     fp8 = "fp8"
-    marlin = "marlin"
 class Dtype(str, Enum):
@@ -105,6 +99,9 @@ def serve(
         "bitsandbytes",
         "bitsandbytes-nf4",
         "bitsandbytes-fp4",
+        "gptq",
+        "awq",
+        "fp8",
     }:
         raise RuntimeError(
             "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
@@ -112,7 +109,7 @@ def serve(
     logger.info("CLI SHARDED = {} DTYPE = {}".format(sharded, dtype))
-    if sharded:
+    if sharded and os.getenv("ATTENTION", "default") not in {"paged"}:
         tgi_file = Path(__file__).resolve().parent / "tgi_service.py"
         num_shard = int(os.getenv("WORLD_SIZE", "1"))
         logger.info("CLI SHARDED = {}".format(num_shard))


@@ -1,43 +1,28 @@
-from text_generation_server.utils.import_utils import SYSTEM
-import os
-from .common import Seqlen
-if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
-    raise ImportError("`USE_FLASH_ATTENTION` is false.")
-if SYSTEM == "cuda":
-    from .cuda import (
-        attention,
-        paged_attention,
-        reshape_and_cache,
-        SUPPORTS_WINDOWING,
-        PREFILL_IN_KV_CACHE,
-    )
-elif SYSTEM == "rocm":
-    from .rocm import (
-        attention,
-        paged_attention,
-        reshape_and_cache,
-        PREFILL_IN_KV_CACHE,
-        SUPPORTS_WINDOWING,
-    )
-elif SYSTEM == "ipex":
-    from .ipex import (
-        attention,
-        paged_attention,
-        reshape_and_cache,
-        PREFILL_IN_KV_CACHE,
-        SUPPORTS_WINDOWING,
-    )
-else:
-    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
-# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
-from .kv_cache import KVCache, get_kv_scales
+from .common import (
+    Seqlen,
+    HPUPagedAttentionMetadata,
+    trim_attn_metadata,
+    trim_seqlen_metadata,
+)
+from .hpu import (
+    SUPPORTS_WINDOWING,
+    attention,
+    paged_attention,
+)
 __all__ = [
     "attention",
-    "get_kv_scales",
     "paged_attention",
-    "reshape_and_cache",
-    "PREFILL_IN_KV_CACHE",
     "SUPPORTS_WINDOWING",
-    "KVCache",
     "Seqlen",
+    "HPUPagedAttentionMetadata",
+    "trim_seqlen_metadata",
+    "trim_attn_metadata",
 ]


@@ -1,72 +1,147 @@
 from dataclasses import dataclass
-from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.models.globals import ATTENTION
 import torch
-from typing import Optional
-if ATTENTION in {"flashinfer", "flashdecoding"}:
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: Optional[torch.Tensor]
-        cu_seqlen_k: Optional[torch.Tensor]
-        max_q: int
-        max_k: int
-        def __init__(
-            self,
-            input_lengths,
-            prefix_lengths,
-            cu_seqlen_q=None,
-            max_q=None,
-            max_k=None,
-        ):
-            self.input_lengths = input_lengths
-            self.prefix_lengths = prefix_lengths
-            device = self.input_lengths.device
-            shape = self.input_lengths.shape
-            if cu_seqlen_q is None:
-                cu_seqlen_q = torch.arange(
-                    shape[0] + 1,
-                    device=device,
-                    dtype=torch.int32,
-                )
-                max_q = 1
-            else:
-                assert max_q is not None
-                assert max_k is not None
-            cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
-            # cuda graphs don't like this and this is necessary to clamp within mistral
-            # Although FA2 might not want the clamping
-            # cu_seqlen_k[0] = 0
-            total = self.input_lengths + self.prefix_lengths
-            torch.cumsum(total, -1, out=cu_seqlen_k[1:])
-            self.cu_seqlen_q = cu_seqlen_q
-            self.cu_seqlen_k = cu_seqlen_k
-            self.max_q = max_q
-            self.max_k = max_k
-        def clamp(self, max):
-            # Flash decoding doesn't need to clamp
-            return self
-else:
-    @dataclass
-    class Seqlen:
-        input_lengths: torch.Tensor
-        prefix_lengths: torch.Tensor
-        cu_seqlen_q: torch.Tensor
-        max_q: int
-        max_k: int
-        def clamp(self, max):
-            if SYSTEM == "rocm":
-                return self
-            raise NotImplementedError("Not implemented seqlen for paged")
-            return Seqlen(torch.clamp(self.input_lengths, max=max))
+from typing import Optional, List, Dict
+import collections
+_TYPE_CACHE = {}
+@dataclass
+class HPUPagedAttentionMetadata:
+    """Metadata for PagedAttention."""
+    block_list: Optional[torch.Tensor]
+    block_mapping: Optional[torch.Tensor]
+    block_usage: Optional[torch.Tensor]
+    block_scales: Optional[torch.Tensor]
+    block_groups: Optional[torch.Tensor]
+    attn_bias: Optional[torch.Tensor]
+def subtuple(
+    obj: object,
+    typename: str,
+    to_copy: List[str],
+    to_override: Optional[Dict[str, object]] = None,
+):
+    if obj is None:
+        return None
+    if to_override is None:
+        to_override = {}
+    fields = set(to_copy) | set(to_override.keys())
+    if isinstance(obj, dict):
+        values = {key: obj[key] for key in fields if key in obj}
+    else:
+        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
+    if typename not in _TYPE_CACHE:
+        _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
+    return _TYPE_CACHE[typename](**values)
+def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
+    # NOTE(kzawora): To anyone working on this in the future:
+    # Trimming metadata is required when using HPUGraphs.
+    # Attention metadata is going to be hashed by PT bridge, and
+    # appropriate HPUGraphs will be matched based on all inputs' hash.
+    # Before you put more keys in here, make sure you know their
+    # value type and make sure you know how it's going to be hashed.
+    # You can find that information in input_hash function
+    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+    # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+    # If you use primitive types here - they will get hashed based
+    # on their value. You *will* get lots of excessive graph captures
+    # (and an OOM eventually) if you decide to put something like
+    # seq_len int here.
+    # If you absolutely need a scalar, put it in a tensor. Tensors
+    # get hashed using their metadata, not their values:
+    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+    # input_hash(123) != input_hash(321)
+    # input_hash("abc") != input_hash("cba")
+    attention_metadata = subtuple(
+        metadata,
+        "TrimmedAttentionMetadata",
+        [
+            "block_list",
+            "block_mapping",
+            "block_usage",
+            "block_scales",
+            "block_groups",
+            "attn_bias",
+        ],
+    )
+    return attention_metadata
+@dataclass
+class Seqlen:
+    input_lengths: torch.Tensor
+    cache_lengths: torch.Tensor
+    cu_seqlen_q: Optional[torch.Tensor]
+    cu_seqlen_k: Optional[torch.Tensor]
+    def __init__(
+        self,
+        input_lengths,
+        cache_lengths,
+        cu_seqlen_q=None,
+    ):
+        self.input_lengths = input_lengths
+        self.cache_lengths = cache_lengths
+        device = self.input_lengths.device
+        shape = self.input_lengths.shape
+        if cu_seqlen_q is None:
+            cu_seqlen_q = torch.arange(
+                shape[0] + 1,
+                device=device,
+                dtype=torch.int32,
+            )
+        cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
+        # cuda graphs don't like this and this is necessary to clamp within mistral
+        # Although FA2 might not want the clamping
+        # cu_seqlen_k[0] = 0
+        total = self.input_lengths + self.cache_lengths
+        torch.cumsum(total, -1, out=cu_seqlen_k[1:])
+        self.cu_seqlen_q = cu_seqlen_q
+        self.cu_seqlen_k = cu_seqlen_k
+    def clamp(self, max):
+        # Flash decoding doesn't need to clamp
+        return self
+def trim_seqlen_metadata(metadata: Seqlen) -> object:
+    # NOTE(kzawora): To anyone working on this in the future:
+    # Trimming metadata is required when using HPUGraphs.
+    # Attention metadata is going to be hashed by PT bridge, and
+    # appropriate HPUGraphs will be matched based on all inputs' hash.
+    # Before you put more keys in here, make sure you know their
+    # value type and make sure you know how it's going to be hashed.
+    # You can find that information in input_hash function
+    # in habana_frameworks/torch/hpu/graphs.py. You can also hash
+    # it manually with torch.hpu.graphs.input_hash(attention_metadata)
+    # If you use primitive types here - they will get hashed based
+    # on their value. You *will* get lots of excessive graph captures
+    # (and an OOM eventually) if you decide to put something like
+    # seq_len int here.
+    # If you absolutely need a scalar, put it in a tensor. Tensors
+    # get hashed using their metadata, not their values:
+    # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
+    # input_hash(123) != input_hash(321)
+    # input_hash("abc") != input_hash("cba")
+    attention_metadata = subtuple(
+        metadata,
+        "TrimmedSeqlen",
+        [
+            "input_lengths",
+            "cache_lengths",
+            "cu_seqlen_q",
+            "cu_seqlen_k",
+        ],
+    )
+    return attention_metadata
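The NOTE comments above explain why metadata must be trimmed to tensor-only fields before HPU graph capture: the PT bridge hashes every input, and primitive scalars hash by value, so each new sequence length would trigger a fresh graph capture. Below is a small CPU-only sketch of the subtuple trimming itself; FakeMetadata and its fields are illustrative rather than TGI types, and subtuple is repeated from the hunk above so the snippet runs standalone.

import collections
from dataclasses import dataclass
from typing import Dict, List, Optional

_TYPE_CACHE = {}


def subtuple(obj, typename: str, to_copy: List[str], to_override: Optional[Dict] = None):
    # Same helper as in the hunk above: copy only the requested fields into a
    # cached namedtuple type, which is what gets hashed for graph matching.
    if obj is None:
        return None
    if to_override is None:
        to_override = {}
    fields = set(to_copy) | set(to_override.keys())
    if isinstance(obj, dict):
        values = {key: obj[key] for key in fields if key in obj}
    else:
        values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
    if typename not in _TYPE_CACHE:
        _TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
    return _TYPE_CACHE[typename](**values)


@dataclass
class FakeMetadata:
    # Illustrative stand-in for HPUPagedAttentionMetadata.
    block_list: list
    block_mapping: list
    debug_name: str  # extra field that should not reach the graph hash


meta = FakeMetadata(block_list=[0, 1], block_mapping=[1, 1], debug_name="batch-0")
trimmed = subtuple(meta, "TrimmedFake", ["block_list", "block_mapping"])
print(trimmed._fields)  # only ('block_list', 'block_mapping'), in some order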


@@ -1,357 +0,0 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
)
from text_generation_server.layers.attention import Seqlen
from typing import Optional
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512
try:
from vllm._C import cache_ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION in {"flashdecoding", "flashinfer"}:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# value_cache => [num_blocks, num_heads, head_size, block_size]
# block_size = value_cache.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import decode_state
return decode_state.get().forward(
query.contiguous(),
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
)
elif ATTENTION == "flashdecoding":
max_q = 1
max_k = max_s
import flash_attn_2_cuda
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
if softcap is None:
softcap = 0.0
out = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None, # pad_k
None,
block_tables,
None,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
softcap,
False, # return softmax
None, # generator
)
return out[0]
else:
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
input_lengths = seqlen.input_lengths
from vllm._C import ops
out = torch.empty_like(query)
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
return out
try:
is_ampere_or_newer = major >= 8 and minor >= 0
if not is_ampere_or_newer:
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
import flash_attn_2_cuda
V2 = True
except ImportError:
try:
import flash_attn_cuda
V2 = False
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = V2
if ATTENTION == "flashinfer":
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
from text_generation_server.layers.attention.flashinfer import (
prefill_with_paged_kv_state,
)
return prefill_with_paged_kv_state.get().forward(
q.contiguous(),
causal=causal,
paged_kv_cache=(key_cache, value_cache),
logits_soft_cap=softcap,
sm_scale=softmax_scale,
window_left=window_size_left,
)
elif V2:
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
None,
block_tables,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
else:
def attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap=None,
):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
)
if softcap is not None:
raise NotImplementedError("softcap is only available with flash attn v2")
# Flash attention v1 requires q, k and v to have the same number of heads
if k.shape[1] != q.shape[1]:
# MQA expand
if k.shape[1] == 1:
k = k.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = k.shape
k = (
k.unsqueeze(2)
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
if v.shape[1] != q.shape[1]:
# MQA expand
if v.shape[1] == 1:
v = v.expand(-1, q.shape[1], -1)
# Grouped attention reshape
else:
original_shape = v.shape
v = (
v.unsqueeze(2)
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
.reshape(original_shape[0], -1, original_shape[2])
)
out = torch.empty_like(q)
flash_attn_cuda.fwd(
q,
k,
v,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
False,
0,
None,
)
return out
# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2


@@ -1,813 +0,0 @@
#!/usr/bin/env python
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
(https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd with causal masking
2) Any sequence lengths without padding (currently fwd kernel only)
3) Support for different sequence lengths for q and k
4) Nested tensor API currently does not support dropout or bias.
Not currently supported:
1) Non power of two head dims
"""
import torch
import triton
import triton.language as tl
torch_dtype: tl.constexpr = torch.float16
@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y
@triton.jit
def max_fn(x, y):
return tl.math.max(x, y)
@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]
@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(
philox_seed, philox_offset, dropout_p, m, n, stride
).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)
@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
@triton.jit
def load_fn(block_ptr, first, second, pad):
if first and second:
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
elif first:
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
elif second:
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
else:
tensor = tl.load(block_ptr)
return tensor
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
actual_seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr,
OFFS_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
MASK_STEPS: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
PADDED_HEAD: tl.constexpr,
):
# loop over k, v, and update accumulator
for start_n in range(block_min, block_max, BLOCK_N):
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
k = load_fn(
K_block_ptr,
PADDED_HEAD,
MASK_STEPS and (n_extra_tokens != 0),
"zero",
)
if PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
# We start from end of seqlen_k so only the first iteration would need
# to be checked for padding if it is not a multiple of block_n
# TODO: This can be optimized to only be true for the padded block.
if MASK_STEPS: # noqa: SIM102
# If this is the last block / iteration, we want to
# mask if the sequence length is not a multiple of block size
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
# if not is_modulo_mn. last step might get wasted but that is okay.
# check if this masking works for that case.
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))
if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
qk = tl.where(causal_mask, qk, float("-inf"))
# -- compute qk ----
qk += tl.dot(q, k)
if bias_ptr is not None:
bias = load_fn(
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
)
# While bias is added after multiplying qk with sm_scale, our
# optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += bias * 1.44269504089
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
philox_offset = (
batch_philox_offset
+ start_m * BLOCK_M * actual_seqlen_k
+ start_n
- BLOCK_N
)
keep = dropout_mask(
philox_seed,
philox_offset,
dropout_p,
BLOCK_M,
BLOCK_N,
actual_seqlen_k,
)
if RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
)
p = tl.where(keep, p, 0.0)
elif RETURN_ENCODED_SOFTMAX:
tl.store(
encoded_softmax_block_ptr,
p.to(encoded_softmax_block_ptr.type.element_ty),
)
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not PRE_LOAD_V:
v = load_fn(
V_block_ptr,
MASK_STEPS and (n_extra_tokens != 0),
PADDED_HEAD,
"zero",
)
# -- update m_i and l_i
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, BLOCK_N)
)
return acc, l_i, m_i
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 64,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 256,
"BLOCK_N": 128,
"waves_per_eu": 2,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": True,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 3,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 64,
"BLOCK_N": 64,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
triton.Config(
{
"BLOCK_M": 32,
"BLOCK_N": 32,
"waves_per_eu": 4,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=8,
),
# TODO: This config fails with head_size not pow2 with data mismatches.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config(
{
"BLOCK_M": 16,
"BLOCK_N": 16,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
triton.Config(
{
"BLOCK_M": 128,
"BLOCK_N": 64,
"waves_per_eu": 1,
"PRE_LOAD_V": False,
},
num_stages=1,
num_warps=4,
),
],
key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
)
@triton.jit
def attn_fwd(
Q,
K,
V,
bias,
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
philox_offset_base,
encoded_softmax,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if VARLEN:
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
# small for all start_m so for those we return early.
if start_m * BLOCK_M > seqlen_q:
return
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
else:
cu_seqlens_q_start = 0
cu_seqlens_k_start = 0
seqlen_q = MAX_SEQLENS_Q
seqlen_k = MAX_SEQLENS_K
# Now we compute whether we need to exit early due to causal masking.
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
# are completely masked, resulting in 0s written to the output, and
# inf written to LSE. We don't need to do any GEMMs in this case.
# This block of code determines what N is, and if this WG is operating
# on those M rows.
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
if IS_CAUSAL:
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
# If seqlen_q != seqlen_k, attn scores are rectangular which means
# the causal mask boundary is bottom right aligned, and ends at either
# the top edge (seqlen_q < seqlen_k) or left edge.
# This captures the decrease in n_blocks if we have a rectangular attn
# matrix
n_blocks_seqlen = cdiv_fn(
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
)
# This is what adjusts the block_max for the current WG, only
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
n_blocks = min(n_blocks, n_blocks_seqlen)
# If we have no blocks after adjusting for seqlen deltas, this WG is
# part of the blocks that are all 0. We exit early.
if n_blocks <= 0:
o_offset = (
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
)
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0
# for these masked blocks.
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
# tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here?
return
# If MQA / GQA, set the K and V head offsets appropriately.
GROUP_SIZE: tl.constexpr = HQ // HK
if GROUP_SIZE != 1:
off_h_k = off_h_q // GROUP_SIZE
else:
off_h_k = off_h_q
n_extra_tokens = 0
if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
# Compute pointers for all the tensors used in this kernel.
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
Q_block_ptr = tl.make_block_ptr(
base=Q + q_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
if BIAS_TYPE != 0:
bias_ptr = tl.make_block_ptr(
base=bias + off_h_q * stride_bh,
shape=(seqlen_q, seqlen_k),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
if ENABLE_DROPOUT:
batch_philox_offset = (
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
)
else:
batch_philox_offset = 0
# We can ask to return the dropout mask without actually doing any dropout.
# In this case, we return an invalid pointer so indicate the mask is not i
# valid.
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.make_block_ptr(
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
shape=(seqlen_q, seqlen_k),
strides=(seqlen_k, 1),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
encoded_softmax_block_ptr = 0
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
if IS_CAUSAL:
# There are always at least BLOCK_M // BLOCK_N masked blocks.
# Additionally there might be one more due to dissimilar seqlens.
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
else:
# Padding on Q does not need to be masked in the FA loop.
masked_blocks = padded_block_k
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
# block. In this case we might exceed n_blocks so pick the min.
masked_blocks = min(masked_blocks, n_blocks)
n_full_blocks = n_blocks - masked_blocks
block_min = 0
block_max = n_blocks * BLOCK_N
# Compute for full blocks. Here we set causal to false regardless of its
# value because there is no masking. Similarly we do not need padding.
if n_full_blocks > 0:
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min,
block_max,
0,
0,
0,
bias_ptr,
# IS_CAUSAL, ....
False,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
False,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
PADDED_HEAD,
)
block_min = block_max
block_max = n_blocks * BLOCK_N
tl.debug_barrier()
# Remaining blocks, if any, are full / not masked.
if masked_blocks > 0:
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
if bias_ptr is not None:
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
if RETURN_ENCODED_SOFTMAX:
encoded_softmax_block_ptr = tl.advance(
encoded_softmax_block_ptr, (0, n_full_blocks)
)
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
K_block_ptr,
V_block_ptr,
start_m,
seqlen_k,
dropout_p,
philox_seed,
batch_philox_offset,
encoded_softmax_block_ptr,
block_min,
block_max,
offs_n_causal,
masked_blocks,
n_extra_tokens,
bias_ptr,
IS_CAUSAL,
BLOCK_M,
BLOCK_DMODEL,
BLOCK_N,
offs_m,
offs_n,
# _, MASK_STEPS, ...
PRE_LOAD_V,
True,
ENABLE_DROPOUT,
RETURN_ENCODED_SOFTMAX,
PADDED_HEAD,
)
# epilogue
acc = acc / l_i[:, None]
if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
# then we have one block with a row of all NaNs which come from computing
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
# and store 0s where there are NaNs as these rows should've been zeroed out.
end_m_idx = (start_m + 1) * BLOCK_M
start_m_idx = start_m * BLOCK_M
causal_start_idx = seqlen_q - seqlen_k
acc = acc.to(Out.type.element_ty)
if IS_CAUSAL: # noqa: SIM102
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
out_mask_boundary = tl.full(
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
)
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size = end_m_idx - seqlen_q
# if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
# else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
O_block_ptr = tl.make_block_ptr(
base=Out + o_offset,
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# Need boundary check on this to make sure the padding from the
# Q and KV tensors in both dims are not part of what we store back.
# TODO: Do the boundary check optionally.
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
def check_args(
q,
k,
v,
o,
varlen=True,
max_seqlens=None,
cu_seqlens_q=None,
cu_seqlens_k=None,
):
assert q.dim() == k.dim() and q.dim() == v.dim()
if varlen:
assert q.dim() == 3
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
assert cu_seqlens_q is not None
assert cu_seqlens_k is not None
assert len(cu_seqlens_q) == len(cu_seqlens_k)
else:
assert q.dim() == 4
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
assert max_seqlens > 0
assert k.shape == v.shape
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype
# TODO: Fix assert to check head size <=256 once supported
assert head_size <= 128
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
class _attention(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
o,
cu_seqlens_q,
cu_seqlens_k,
max_seqlens_q,
max_seqlens_k,
causal=False,
sm_scale=1.0,
bias=None,
):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
check_args(
q,
k,
v,
o,
varlen=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
)
if True: # varlen
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = len(cu_seqlens_q) - 1
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length()
padded_d_model = max(padded_d_model, 16)
def grid(META):
return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch
encoded_softmax = None
# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
if bias is not None:
bias_strides = (
bias.stride(0),
bias.stride(1),
bias.stride(2),
bias.stride(3),
)
else:
bias_strides = (0, 0, 0, 0)
attn_fwd[grid](
q,
k,
v,
bias,
sm_scale,
None,
o,
*q_strides,
*k_strides,
*v_strides,
*o_strides,
*bias_strides,
cu_seqlens_q,
cu_seqlens_k,
dropout_p=0.0,
philox_seed=philox_seed,
philox_offset_base=philox_offset,
encoded_softmax=encoded_softmax,
HQ=nheads_q,
HK=nheads_k,
ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k,
IS_CAUSAL=causal,
VARLEN=True,
BLOCK_DMODEL=padded_d_model,
BIAS_TYPE=0 if bias is None else 1,
ENABLE_DROPOUT=False,
RETURN_ENCODED_SOFTMAX=False,
)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
ctx.causal = causal
ctx.dropout_p = 0.0
ctx.philox_seed = philox_seed
ctx.philox_offset = philox_offset
ctx.encoded_softmax = encoded_softmax
ctx.return_encoded_softmax = False
return o, encoded_softmax
triton_attention = _attention.apply


@@ -1,251 +0,0 @@
from typing import Optional
from contextvars import ContextVar
from contextlib import contextmanager
import flashinfer
import torch
prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
"prefill_state"
)
prefill_with_paged_kv_state: ContextVar[
flashinfer.BatchPrefillWithPagedKVCacheWrapper
] = ContextVar("prefill_with_paged_kv_state")
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
"decode_state"
)
workspace: Optional[torch.Tensor] = None
def get_workspace(device):
"""Get shared flashinfer workspace."""
global workspace
if workspace is None:
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
return workspace
def create_prefill_with_paged_kv_state(
*,
device: torch.device,
):
"""Create a prefill state that uses the KV cache."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_with_paged_kv_state(
*,
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
block_tables: torch.Tensor,
cu_seqlens: torch.Tensor,
input_lengths: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
# Get the lengths of the last page in a block.
if page_size == 1:
last_page_len = torch.ones(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
else:
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
token = prefill_with_paged_kv_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
paged_kv_indptr=indptr,
paged_kv_indices=block_tables,
paged_kv_last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=dtype,
page_size=page_size,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_with_paged_kv_state.reset(token)
def create_prefill_state(
*,
device: torch.device,
):
"""Create a prefill state."""
workspace_buffer = get_workspace(device)
return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
)
@contextmanager
def use_prefill_state(
*,
state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
cu_seqlens: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer prefill state to the given
`state` and parameters. This state will be used by all calls to the
`attention` function while the context manager is active.
"""
token = prefill_state.set(state)
try:
state.begin_forward(
qo_indptr=cu_seqlens,
kv_indptr=cu_seqlens,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
prefill_state.reset(token)
def create_decode_state(
*,
device: torch.device,
num_heads: int,
num_kv_heads: int,
):
"""Create a decode state."""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=False,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
def create_decode_state_cuda_graphs(
*,
device: torch.device,
block_tables: torch.Tensor,
block_tables_ptr: torch.Tensor,
last_page_len: torch.Tensor,
num_heads: int,
num_kv_heads: int,
):
"""
Create a decode state for use with CUDA Graphs. `block_tables`,
`block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
therefore stored as part of the state.
"""
workspace_buffer = get_workspace(device)
num_groups = num_heads // num_kv_heads
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer,
kv_layout="NHD",
use_cuda_graph=True,
paged_kv_indices_buffer=block_tables,
paged_kv_indptr_buffer=block_tables_ptr,
paged_kv_last_page_len_buffer=last_page_len,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores=num_groups not in [1, 2, 4, 8],
)
@contextmanager
def use_decode_state(
*,
state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
input_lengths: torch.Tensor,
block_tables: torch.Tensor,
num_heads: int,
num_kv_heads: int,
head_size: int,
page_size: int,
dtype: torch.dtype,
window_left: int,
):
"""
Context manager to set the active flashinfer decoding state to the given
`state` and parameters. This state will be used by all calls to the
`paged_attention` function while the context manager is active.
"""
indptr = torch.zeros(
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
)
# Round up to page size and then calculate the cumulative sum to get
# the indices into the block table.
torch.add(input_lengths, page_size - 1, out=indptr[1:])
indptr[1:].div_(page_size, rounding_mode="floor")
indptr[1:].cumsum_(-1)
# Get the lengths of the last page in a block.
last_page_len = torch.empty(
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
)
torch.sub(input_lengths, 1, out=last_page_len)
last_page_len.remainder_(page_size)
last_page_len += 1
token = decode_state.set(state)
try:
state.begin_forward(
indptr=indptr,
indices=block_tables,
last_page_len=last_page_len,
num_qo_heads=num_heads,
num_kv_heads=num_kv_heads,
head_dim=head_size,
page_size=page_size,
data_type=dtype,
q_data_type=dtype,
window_left=window_left,
)
yield
finally:
state.end_forward()
if token is not None:
decode_state.reset(token)


@@ -0,0 +1,95 @@
import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os
SUPPORTS_WINDOWING = False
def fetch_from_cache(cache, blocks):
if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
return cache[: blocks.size(0)]
else:
return cache.index_select(0, blocks)
def attention(
*,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: KVCache,
kv_scales: KVScales,
seqlen: Seqlen,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
bs = seqlen.input_lengths.shape[0]
_, head_num, head_size = query.shape
_, kv_head_num, head_size = key.shape
query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=softmax_scale,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=seqlen.input_lengths,
padding_side="left",
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
return attn_output
def paged_attention(
query: torch.Tensor,
kv_cache: KVCache,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
seqlen: Seqlen,
*,
kv_scales: KVScales,
softcap: Optional[float] = None,
hpu_attention_meta: HPUPagedAttentionMetadata,
):
batch_size, head_num, head_size = query.shape
output = ops.flat_pa(
query=query.view(batch_size, 1, head_num * head_size),
key_cache=kv_cache.key,
value_cache=kv_cache.value,
block_list=hpu_attention_meta.block_list,
block_mapping=hpu_attention_meta.block_mapping,
block_bias=hpu_attention_meta.attn_bias,
block_scales=hpu_attention_meta.block_scales,
block_groups=hpu_attention_meta.block_groups,
scale=softmax_scale,
matmul_qk_op=Matmul(),
matmul_av_op=Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=fetch_from_cache,
values_fetch_func=fetch_from_cache,
)
# Reshape the output tensor.
return output.view(batch_size, head_num, head_size)
__all__ = [
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
]
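To give a rough idea of the layout work `attention` does in this new HPU backend, here is a CPU-only sketch that uses plain PyTorch SDPA in place of Habana's FusedSDPA; shapes and values are illustrative, not the HPU kernel itself.

import torch
import torch.nn.functional as F

bs, seq, head_num, kv_head_num, head_size = 2, 4, 8, 8, 64
total_tokens = bs * seq

# TGI hands the backend flattened [total_tokens, heads, head_size] tensors.
query = torch.randn(total_tokens, head_num, head_size)
key = torch.randn(total_tokens, kv_head_num, head_size)
value = torch.randn(total_tokens, kv_head_num, head_size)

# Same reshape as the HPU prefill path:
# [bs, seq, heads, head_size] -> [bs, heads, seq, head_size].
q = query.view(bs, -1, head_num, head_size).transpose(1, 2)
k = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
v = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=head_size**-0.5)

# Undo the layout change, mirroring `attn_output.transpose(1, 2)` above.
out = out.transpose(1, 2).reshape(total_tokens, head_num, head_size)
print(out.shape)  # torch.Size([8, 8, 64])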


@@ -1,82 +0,0 @@
import intel_extension_for_pytorch as ipex
import torch
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen
from typing import Optional
SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False
def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap: Optional[float] = None,
):
out = torch.empty_like(q)
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
ipex.llm.functional.varlen_attention(
q.contiguous() if q.device.type == "xpu" else q,
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_q,
0.0,
softmax_scale,
False,
causal,
False,
None,
)
return out
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
out = torch.empty_like(query)
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
seqlen.input_lengths,
BLOCK_SIZE,
max_s,
None,
)
return out


@@ -0,0 +1,139 @@
from typing import Tuple
from dataclasses import dataclass, field
import torch
from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights
from vllm_hpu_extension import cache_ops
@dataclass
class KVScales:
"""
Key-value scales for FP8 KV cache.
This data class stores key and value scales both as a GPU tensor and
as a GPU float. This inconvenience is necessary because some functions
(e.g. scaling kernels) take scales as a GPU tensor, whereas others
(e.g. flashinfer) take scales as a CPU scalar.
"""
key_scale: torch.Tensor
value_scale: torch.Tensor
key_scale_cpu: float = field(init=False)
value_scale_cpu: float = field(init=False)
def __post_init__(self):
if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
raise ValueError("Key and value scales must be scalar tensors.")
self.key_scale_cpu = self.key_scale.item()
self.value_scale_cpu = self.value_scale.item()
class KVCache:
"""
Key-value cache for attention layers.
"""
kv_cache: Tuple[torch.Tensor, torch.Tensor]
def __init__(
self,
*,
num_blocks: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
):
"""Construct the key-value cache for a layer."""
## TODO FP8 kv cache support
self.kv_cache = (
torch.zeros(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
torch.zeros(
(num_blocks, BLOCK_SIZE, num_heads, head_size),
dtype=dtype,
device=device,
),
)
@property
def dtype(self):
"""Get the data type of the cache."""
return self.kv_cache[0].dtype
@property
def key(self):
"""Get the key cache."""
return self.kv_cache[0]
@property
def value(self):
"""Get the value cache."""
return self.kv_cache[1]
def store(
self,
*,
key: torch.Tensor,
value: torch.Tensor,
slots: torch.Tensor,
kv_scales: KVScales,
):
"""Store the key and value at the given slots."""
## TODO FP8 kv cache support
key_cache = self.kv_cache[0]
value_cache = self.kv_cache[1]
paged_reshape_and_cache(
key,
value,
key_cache,
value_cache,
slots,
kv_scales.key_scale_cpu,
kv_scales.value_scale_cpu,
)
def paged_reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
k_scale: float = 1.0,
v_scale: float = 1.0,
):
block_idx = slots // BLOCK_SIZE
block_offset = slots % BLOCK_SIZE
cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)
def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
"""Load KV cache scales."""
key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
value_scale = key_scale
if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
f"{prefix}.v_scale"
):
key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
elif weights.has_tensor(f"{prefix}.kv_scale"):
# Fall back to older more coarse-grained scale when available.
key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
value_scale = key_scale
return KVScales(key_scale=key_scale, value_scale=value_scale)
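The paged KV cache above addresses tokens by flat slot index, and each slot maps to a (block, offset) pair. A CPU-only sketch of that arithmetic and of an unfused cache write follows; BLOCK_SIZE and the tensor shapes are illustrative, and the real path goes through vllm_hpu_extension's insert_or_update_cache instead of plain indexing.

import torch

BLOCK_SIZE = 16  # illustrative; TGI configures its own block size
slots = torch.tensor([0, 1, 17, 40])

block_idx = slots // BLOCK_SIZE    # which cache block each token lands in
block_offset = slots % BLOCK_SIZE  # position inside that block
print(block_idx.tolist())          # [0, 0, 1, 2]
print(block_offset.tolist())       # [0, 1, 1, 8]

# Unfused equivalent of the cache write: scatter new key rows into a
# [num_blocks, BLOCK_SIZE, num_heads, head_size] cache, as allocated by KVCache.
num_blocks, num_heads, head_size = 4, 2, 8
key_cache = torch.zeros(num_blocks, BLOCK_SIZE, num_heads, head_size)
key = torch.randn(slots.shape[0], num_heads, head_size)
key_cache[block_idx, block_offset] = key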


@@ -1,308 +0,0 @@
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.utils.log import log_master
from loguru import logger
major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
PREFILL_IN_KV_CACHE = False
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
try:
if use_rocm_custom_paged_attn:
from vllm._custom_C import paged_attention_custom
except ImportError as e:
log_master(
logger.info,
f"Custom Paged Attention not available. Complete error: {e}",
)
use_rocm_custom_paged_attn = False
try:
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION == "flashdecoding":
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if softcap is not None:
raise RuntimeError("Paged attention doesn't support softcapping")
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
use_rocm_custom_paged_attn
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
)
if not use_custom:
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
else:
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = seqlen.input_lengths
out = torch.empty_like(query)
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
import vllm._custom_ops as ops
use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
and not use_custom
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
if not use_custom:
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
paged_attention_custom(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
)
return out
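To make the V1/V2 selection above easier to follow, here is a hedged, standalone sketch of the same partition-count arithmetic and heuristic (pure Python, no vLLM kernels; the custom-kernel condition is omitted and the sample numbers are made up):

# Simplified sketch of the partition/heuristic logic in the (removed) ROCm paged_attention path.
_PARTITION_SIZE = 512

def would_use_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE  # ceiling division
    # V1 avoids the reduction overhead when there is a single partition,
    # or when there is already enough parallel work (many sequences/heads).
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)

print(would_use_v1(max_s=512, num_seqs=1, num_heads=32))   # True: only one partition
print(would_use_v1(max_s=4096, num_seqs=4, num_heads=8))   # False: 8 partitions, little work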
if ENGINE != "triton":
try:
import flash_attn_2_cuda
log_master(
logger.info,
"ROCm: using Flash Attention 2 Composable Kernel implementation.",
)
except ImportError as e:
if major >= 8:
architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
elif is_sm75:
raise ImportError(
"Flash Attention is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
else:
for idx in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(idx)
if "MI210" not in name and "MI250" not in name:
raise ImportError(
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
)
raise ImportError(
f"AMD GPU with ROCm capability {major} {minor} is not supported"
) from e
SUPPORTS_WINDOWING = False
if ENGINE == "ck":
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: float = 0.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
out = torch.empty_like(q)
# We do not need to check window_size_left here (not supported), as it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
None,
None,
None,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)[0]
elif ENGINE == "triton":
from .flash_attn_triton import triton_attention
def attention(
q,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale: float,
window_size_left: int = -1,
causal: bool = True,
softcap: Optional[float] = None,
):
if softcap is not None:
raise NotImplementedError("softcap is only available with CK flash attn")
out = torch.empty_like(q)
# We do not need to check window_size_left here (not supported), as it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
key_cache,
value_cache,
out,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
causal,
softmax_scale,
)
return output
else:
raise RuntimeError(f"Unknown attention engine {ENGINE}")

View File

@ -0,0 +1,3 @@
from .hpu import WQLinear
__all__ = ["WQLinear"]

View File

@ -0,0 +1,134 @@
from typing import Optional
import torch
import torch.nn as nn
try:
import habana_frameworks.torch.hpu # noqa: F401
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(
qzeros[:, :, None], shifts[None, None, :]
).to(
torch.int8 # smallest dtype available
)
izeros = izeros.view(izeros.shape[0], -1)
else:
izeros = qzeros
return iweights, izeros
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(
iweights.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
if izeros is not None:
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
return iweights, izeros
def unpack_weight_and_zeros(qweight, qzeros, bits):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
return iweight, izeros
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros(
(normal.shape[0], normal.shape[1] // 32 * bits),
dtype=torch.int32,
device=input.device,
)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
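As a quick illustration of what pack_tensor does: eight 4-bit values share one int32 column, with value j stored at bit offset 4 * (j - i). A tiny hedged example with made-up values (plain Python integers, so no int32 sign-bit wrap to worry about; pack_tensor applies the same shift pattern column-wise on int32 tensors):

vals = [1, 2, 3, 4, 5, 6, 7, 0]            # eight 4-bit values for one packed slot
packed = 0
for j, v in enumerate(vals):
    packed |= v << (4 * j)                 # same shift pattern as the loop above
unpacked = [(packed >> (4 * j)) & 0xF for j in range(8)]
assert unpacked == vals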
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
self._preprocessing()
def _preprocessing(self):
device = self.qweight.device
weight, zeros = unpack_weight_and_zeros(
self.qweight.cpu(), self.qzeros.cpu(), self.w_bit
)
self.qweight = pack_tensor(weight).to(device)
self.qzeros = pack_tensor(zeros).to(device)
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
x = x.reshape(-1, x.shape[-1])
weights = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
outputs = torch.matmul(x, weights)
outputs = outputs + self.bias if self.bias is not None else outputs
outputs = outputs.reshape(out_shape)
return outputs

View File

@ -1,49 +0,0 @@
# Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
from typing import Optional
import torch
import torch.nn as nn
import awq_inference_engine # with CUDA kernels
# class ScaledActivation(nn.Module):
# def __init__(self, module, scales):
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
class WQLinear(nn.Module):
def __init__(
self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else self.in_features
# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert self.out_features % (32 // self.w_bit) == 0
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.bias = bias
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
out = awq_inference_engine.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -1,43 +0,0 @@
from dataclasses import dataclass
import torch
from EETQ import quant_weights, w8_a16_gemm
from text_generation_server.utils.weights import UnquantizedWeight
@dataclass
class EETQWeight(UnquantizedWeight):
weight: torch.Tensor
def get_linear(self, bias: torch.Tensor):
try:
from text_generation_server.layers.eetq import EETQLinear
return EETQLinear(self.weight, bias)
except ImportError:
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
class EETQLinear(torch.nn.Module):
def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
device = weight.device
if weight.dtype != torch.float16:
weight = weight.to(dtype=torch.float16)
weight = torch.t(weight).contiguous().cpu()
weight, scale = quant_weights(weight, torch.int8, False)
self.weight = weight.cuda(device)
self.scale = scale.cuda(device)
self.bias = bias.cuda(device) if bias is not None else None
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = w8_a16_gemm(input, self.weight, self.scale)
output = output + self.bias if self.bias is not None else output
return output

View File

@ -1,100 +1,152 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Type, Union, List
import torch
from dataclasses import dataclass
from typing import Optional, Union, List
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import (
Weight,
WeightsLoader,
UnquantizedWeight,
Weights,
)
from text_generation_server.utils.log import log_master, log_once
import importlib.util
from vllm_hpu_extension.ops import scaled_fp8_quant
from vllm_hpu_extension.scales import get_hpu_gaudi2_scale_factor, is_hpu_gaudi2
import habana_frameworks.torch.utils.experimental as htexp
w8a8_block_fp8_matmul = None
per_token_group_quant_fp8 = None
quant_dtype: torch.dtype = torch.float8_e4m3fn
FBGEMM_MM_AVAILABLE = False
FBGEMM_DYN_AVAILABLE = False
def is_fbgemm_gpu_available():
try:
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
except ModuleNotFoundError:
return False
if is_fbgemm_gpu_available():
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
FBGEMM_MM_AVAILABLE = major == 9
FBGEMM_DYN_AVAILABLE = major >= 8
else:
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
def get_fp8_linear() -> torch.nn.Module:
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
"""
Return an FP8 linear `Module` that is compatible with the current system.
"""
if SYSTEM == "cuda":
major, _ = torch.cuda.get_device_capability()
if major == 8:
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
return GPTQMarlinFP8Linear
# On other systems let Torch decide if the hardware supports FP8.
return Fp8Linear
def fp8_quantize(
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
):
if FBGEMM_DYN_AVAILABLE and not scalar:
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
)
return qweight, scale
# weight, scale = quant_weights(weight, torch.int8, False)
finfo = torch.finfo(qdtype)
# Calculate the scale as dtype max divided by absmax
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(qdtype)
scale = scale.float().reciprocal()
return qweight, scale
def normalize_e4m3fn_to_native_float8(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return weight, weight_scale, input_scale
def per_tensor_dequantize(
tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor],
dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
device = tensor.device
dtype = torch.bfloat16
if htexp._get_device_type() == htexp.synDeviceType.synDeviceGaudi2:
# dequant on cpu to avoid nan on gaudi2
tensor = tensor.to("cpu")
fake_qweight = tensor.to(dtype).to(device)
dq_weight = fake_qweight * inv_scale
return dq_weight
def requantize_with_max_scale(
weight: torch.Tensor,
weight_scale: torch.Tensor,
logical_widths: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Max scale to be used for requantization.
max_w_scale = weight_scale.max()
if is_hpu_gaudi2():
max_w_scale = max_w_scale * get_hpu_gaudi2_scale_factor()
start = 0
for idx, logical_width in enumerate(logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(
weight[start:end, :], weight_scale[idx], dtype
)
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
weight_dq, max_w_scale
)
start = end
return weight, max_w_scale_normalized
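The intent of requantize_with_max_scale above is to give several logically distinct shards (e.g. fused projections loaded together) one shared scale: each shard is dequantized with its own scale and re-quantized against the maximum. A rough CPU-only sketch of that idea, deliberately omitting the Gaudi2 scale-factor correction and using plain tensors rather than the helpers above:

import torch

def requantize_to_max_scale_ref(shards, scales, qdtype=torch.float8_e4m3fn):
    # shards: list of fp8 weight shards; scales: their matching reciprocal scales.
    finfo = torch.finfo(qdtype)
    max_scale = torch.stack(scales).max()
    requant = []
    for shard, scale in zip(shards, scales):
        dq = shard.to(torch.float32) * scale                          # dequantize per shard
        q = (dq / max_scale).clamp(finfo.min, finfo.max).to(qdtype)   # requantize with the shared scale
        requant.append(q)
    return torch.cat(requant), max_scale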
def fp8_quantize(
weight: torch.Tensor,
scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[torch.Tensor] = None,
qdtype: torch.dtype = torch.float8_e4m3fn,
scalar: bool = False,
):
"""
This function returns a reciprocal of the scale, so that a tensor can be unscaled
by multiplying it with the returned scale. If a scale is given through the `scale`
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
be used without modification).
"""
shape = weight.shape
qweight, scale = scaled_fp8_quant(
weight.reshape(-1, shape[-1]),
scale=scale,
scale_ub=scale_upper_bound,
# TODO: don't do this when we have to use the Torch kernel.
use_per_token_if_dynamic=not scalar,
)
return qweight.reshape(shape), scale
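Note the contract spelled out in the docstring above: the returned scale is a reciprocal, so callers unscale by multiplying. A minimal round-trip illustration in plain PyTorch (not the vllm_hpu_extension kernel used above):

import torch

finfo = torch.finfo(torch.float8_e4m3fn)
w = torch.randn(4, 8)
scale = finfo.max / w.abs().max().clamp(min=1e-12)        # quantization scale
qw = (w * scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
inv_scale = scale.reciprocal()                            # this reciprocal is what fp8_quantize returns
w_roundtrip = qw.to(torch.float32) * inv_scale            # dequantize by multiplying
assert torch.allclose(w, w_roundtrip, atol=0.2, rtol=0.2)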
class HybridFP8UnquantLoader(WeightsLoader):
"""Weight loader that loads FP8 and unquantized Torch tensors."""
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
def __init__(
self,
activation_scale_ub: Optional[float],
to_fp8: bool,
weight_block_size: Optional[List[int]] = None,
):
self.activation_scale_ub = activation_scale_ub
self.to_fp8 = to_fp8
self.weight_block_size = weight_block_size
def get_weights(self, weights: "Weights", prefix: str):
w = weights.get_tensor(f"{prefix}.weight")
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
# FP8 branch
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
) )
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -116,6 +168,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
if w.dtype == torch.float8_e4m3fn:
# FP8 branch
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
if scale.numel() > 1:
scale = weights.get_packed_sharded(
f"{prefix}.weight_scale",
@ -123,11 +176,29 @@ class HybridFP8UnquantLoader(WeightsLoader):
block_sizes=block_sizes,
to_dtype=False,
)
scale = scale.reshape(-1).expand(w.shape[0])
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = weights.get_tensor(
f"{prefix}.input_scale", to_dtype=False
)
if input_scale.numel() > 1:
input_scale = weights.get_packed_sharded(
f"{prefix}.input_scale",
dim=0,
block_sizes=block_sizes,
to_dtype=False,
)
input_scale = input_scale.reshape(-1).max()
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -148,15 +219,48 @@ class HybridFP8UnquantLoader(WeightsLoader):
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
if self.weight_block_size is not None:
scale = [
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
for p in prefixes
]
scale = torch.cat(scale, dim=dim)
scale = scale.to(weights.device)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
for p, shape in zip(prefixes, shapes)
]
scale = torch.cat(scale, dim=0).reshape(-1)
input_scale = [
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
for p, shape in zip(prefixes, shapes)
if weights.has_tensor(f"{p}.input_scale")
]
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
input_scale = (
torch.cat(input_scale, dim=0).reshape(-1).max()
if len(input_scale) != 0
else None
)
logical_widths = [x[0] for x in shapes]
w, scale = requantize_with_max_scale(
w, scale.to(weights.device), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -169,14 +273,35 @@ class HybridFP8UnquantLoader(WeightsLoader):
w = weights.get_sharded(f"{prefix}.weight", dim=1)
# FP8 branch
if w.dtype == torch.float8_e4m3fn:
scale = (
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
.reshape(-1)
.expand(w.shape[0])
if self.weight_block_size is not None:
# XXX: Yes the weights is named scale_inv, but corresponds to scale it seems.
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
return Fp8Weight(
weight=w,
weight_scale=scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
weight_block_size=self.weight_block_size,
)
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
input_scale = None
if weights.has_tensor(f"{prefix}.input_scale"):
input_scale = (
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
.reshape(-1)
.max()
)
logical_widths = [w.shape[0]]
w, scale = requantize_with_max_scale(
w, scale.unsqueeze(0), logical_widths, weights.dtype
)
return Fp8Weight(
weight=w,
weight_scale=scale,
input_scale=input_scale,
activation_scale_ub=self.activation_scale_ub,
dtype=weights.dtype,
)
@ -191,83 +316,126 @@ class Fp8Weight(Weight):
weight: torch.Tensor
dtype: torch.dtype
weight_scale: Optional[torch.Tensor] = None
input_scale: Optional[torch.Tensor] = None
activation_scale_ub: Optional[float] = None
force_w8a16: bool = False
weight_block_size: Optional[List[int]] = None
def get_linear(self, bias: torch.Tensor):
if self.weight_scale is None:
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
self.weight, bias, self.dtype
)
# This is not checked by the fbgemm kernels, but they require contiguous
# memory. Can be non-contiguous when we e.g. expand from scalars.
self.weight_scale = self.weight_scale.contiguous()
return get_fp8_linear().from_fp8(
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
weight=self.weight,
scale=self.weight_scale,
dtype=self.dtype,
bias=bias,
input_scale=self.input_scale,
scale_upper_bound=self.activation_scale_ub,
weight_block_size=self.weight_block_size,
)
class Fp8Linear(torch.nn.Module):
_device_identity_cache = {}
def __init__(
self,
qweight,
scale,
scale_upper_bound,
bias,
dtype,
qweight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
scale_upper_bound: Optional[float] = None,
weight_block_size: Optional[List[int]] = None,
) -> None:
super().__init__()
if FBGEMM_MM_AVAILABLE:
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
self.dtype = dtype
self.qweight = qweight
self.scale = scale
self.scale_upper_bound = (
torch.tensor(
[scale_upper_bound], dtype=torch.float32, device=qweight.device
)
if scale_upper_bound is not None
else None
)
self.scale = scale.float()
self.input_scale = input_scale.float() if input_scale is not None else None
self.weight_block_size = weight_block_size
self.scale_upper_bound = scale_upper_bound
self.bias = bias if bias is not None else None
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
qweight, scale = fp8_quantize(weight, scalar=True)
return cls(
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
qweight=qweight,
scale=scale,
dtype=dtype,
bias=bias,
input_scale=None,
scale_upper_bound=None,
)
@classmethod
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
if FBGEMM_DYN_AVAILABLE:
# fbgemm needs float32 scales.
scale = scale.float()
def from_fp8(
cls,
weight: torch.Tensor,
scale: torch.Tensor,
dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> "Fp8Linear":
input_scale = kwargs.get("input_scale", None)
scale_upper_bound = kwargs.get("scale_upper_bound", None)
weight_block_size = kwargs.get("weight_block_size", None)
return cls(
qweight=weight,
scale=scale,
scale_upper_bound=input_scale,
input_scale=input_scale,
scale_upper_bound=scale_upper_bound,
bias=bias,
dtype=dtype,
weight_block_size=weight_block_size,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
if FBGEMM_MM_AVAILABLE:
qinput, scale = fp8_quantize(
input, scale_upper_bound=self.scale_upper_bound
)
@classmethod
def get_shared_device_identity(cls, device):
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
if device not in cls._device_identity_cache:
cls._device_identity_cache[device] = torch.ones(1, device=device)
return cls._device_identity_cache[device]
y = torch.ops.fbgemm.f8f8bf16_rowwise(
def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.weight_block_size is not None:
# https://arxiv.org/pdf/2412.19437
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
# channels).
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
output = w8a8_block_fp8_matmul(
qinput,
self.qweight,
scale,
self.scale,
use_fast_accum=True,
bias=self.bias,
self.weight_block_size,
output_dtype=input.dtype,
)
return y.to(self.dtype)
qinput, scale = fp8_quantize(input, scalar=True)
output, _ = torch._scaled_mm(
if self.bias is not None:
output = output + self.bias
return output.to(dtype=input.dtype)
qinput, scale = fp8_quantize(
input,
self.input_scale,
scale_upper_bound=self.scale_upper_bound,
scalar=True,
)
output = torch._scaled_mm(
qinput,
self.qweight.t(),
out_dtype=self.dtype,
@ -275,11 +443,16 @@ class Fp8Linear(torch.nn.Module):
scale_b=self.scale,
bias=self.bias,
)
if isinstance(output, tuple) and len(output) == 2:
output = output[0]
return output
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
scale = weights.get_tensor(prefix, to_dtype=False)
if scale.numel() > 1:
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
return scale.reshape(-1).expand(shape[0])
return scale.reshape(-1)
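The block-quantization comment inside Fp8Linear.forward above describes 1x128 activation tiles with one scale per (token, 128-channel group). The snippet below is a reference-only sketch of that tiling in plain PyTorch; it is not the per_token_group_quant_fp8 kernel used by the backend, and the helper name is hypothetical:

import torch

def per_token_group_quant_ref(x: torch.Tensor, group_size: int = 128):
    # Quantize each row in groups of `group_size` channels: one scale per 1 x 128 tile.
    finfo = torch.finfo(torch.float8_e4m3fn)
    tokens, hidden = x.shape
    xg = x.reshape(tokens, hidden // group_size, group_size)
    scale = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    q = (xg / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return q.reshape(tokens, hidden), scale.squeeze(-1)

q, s = per_token_group_quant_ref(torch.randn(2, 256), group_size=128)  # 2 tokens, 2 groups each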

View File

@ -1,14 +1,15 @@
import os
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
from .hpu import QuantLinear
@dataclass
class GPTQWeight(Weight):
qweight: torch.Tensor
@ -30,13 +31,8 @@ class GPTQWeight(Weight):
def get_linear(self, bias: torch.Tensor):
if self.use_awq_kernel:
if SYSTEM == "rocm":
raise NotImplementedError(
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
"to use Exllama/GPTQ kernels for AWQ inference."
)
try:
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
from text_generation_server.layers.awq.quantize import WQLinear
return WQLinear(
w_bit=self.bits,
@ -50,18 +46,7 @@ class GPTQWeight(Weight):
raise NotImplementedError(
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `--quantize gptq`; a conversion AWQ->GPTQ will happen on the fly"
)
elif self.use_exllama:
try:
from text_generation_server.layers.gptq import ExllamaQuantLinear
except ImportError:
raise NotImplementedError(
"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
)
return ExllamaQuantLinear(self, bias)
else:
from text_generation_server.layers.gptq.quant_linear import QuantLinear
return QuantLinear(
self.qweight,
self.qzeros,
@ -118,23 +103,6 @@ class GPTQWeightsLoader(WeightsLoader):
else:
g_idx = None
from text_generation_server.layers.gptq import (
HAS_EXLLAMA,
CAN_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
@ -247,14 +215,7 @@ class GPTQWeightsLoader(WeightsLoader):
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = self.bits == 4 and self.quantize == "gptq" and not self.desc_act
use_exllama = (
self.bits == 4
and HAS_EXLLAMA
and self.quantize == "gptq"
and not self.desc_act
)
if self.quantize == "gptq" and self.quant_method == "gptq":
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
@ -298,6 +259,7 @@ class GPTQWeightsLoader(WeightsLoader):
self._get_gptq_params(weights)
use_exllama = True
desc_act = self.desc_act
if self.bits != 4:
use_exllama = False
@ -321,7 +283,8 @@ class GPTQWeightsLoader(WeightsLoader):
if g_idx is not None:
if (
not torch.equal(
g_idx.cpu(),
# Remove g_idx[0] to adapt the check with TP>1.
(g_idx - g_idx[0]).cpu(),
torch.tensor(
[i // self.groupsize for i in range(g_idx.shape[0])],
dtype=torch.int32,
@ -332,34 +295,22 @@ class GPTQWeightsLoader(WeightsLoader):
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering the input activations that are split onto several GPUs
use_exllama = False
desc_act = True
from text_generation_server.layers.gptq import (
CAN_EXLLAMA,
HAS_EXLLAMA,
GPTQWeight,
)
if use_exllama:
if not desc_act and self.groupsize != -1:
if not HAS_EXLLAMA:
if CAN_EXLLAMA:
log_once(
logger.warning,
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
)
use_exllama = False
else:
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
if use_exllama and self.groupsize != -1:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
if g_idx is not None:
# qzeros, scales sharded, and g_idx must be adjusted accordingly
g_idx = g_idx - g_idx[0]
else:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
scales = weights.get_tensor(f"{prefix}.scales")
if use_exllama and g_idx is not None:
g_idx = g_idx - g_idx[0]
if self.quantize == "gptq" and self.quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
@ -392,7 +343,7 @@ class GPTQWeightsLoader(WeightsLoader):
)
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
@ -400,41 +351,7 @@ class GPTQWeightsLoader(WeightsLoader):
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
if weights.has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq" self.quant_method = "gptq"
# Needs to be at the end because circular import.
try:
major, _minor = torch.cuda.get_device_capability()
except Exception:
major = 1
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
if V2:
from text_generation_server.layers.gptq.exllamav2 import (
QuantLinear as ExllamaQuantLinear, # noqa: F401
create_exllama_buffers, # noqa: F401
set_device, # noqa: F401
)
HAS_EXLLAMA = "2"
else:
from text_generation_server.layers.gptq.exllama import (
Ex4bitLinear as ExllamaQuantLinear, # noqa: F401
create_exllama_buffers, # noqa: F401
set_device, # noqa: F401
)
HAS_EXLLAMA = "1"
except ImportError:
pass
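The `g_idx - g_idx[0]` rebasing introduced in the loader above makes the "trivial act-order" check independent of which slice of the row dimension a tensor-parallel rank holds. A small illustration with a hypothetical groupsize of 2 and TP=2:

import torch

groupsize = 2  # hypothetical
# Full (trivial) g_idx for 8 rows: [0, 0, 1, 1, 2, 2, 3, 3]
full = torch.tensor([i // groupsize for i in range(8)], dtype=torch.int32)
shard = full[4:]                 # rank 1 sees [2, 2, 3, 3]
rebased = shard - shard[0]       # [0, 0, 1, 1] -> matches the trivial pattern again
expected = torch.tensor([i // groupsize for i in range(4)], dtype=torch.int32)
assert torch.equal(rebased, expected)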

View File

@ -1,261 +0,0 @@
# https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""
import builtins
import math
import time
from typing import Dict
import triton
class Autotuner(triton.KernelInterface):
def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
prune_configs_by: Dict = None,
nearest_power_of_two: bool = False,
):
"""
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs: List[Config] as its input and returns pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
"""
if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.nearest_power_of_two = nearest_power_of_two
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = (
prune_configs_by["perf_model"],
prune_configs_by["top_k"],
)
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**current,
)
try:
# In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
return triton.testing.do_bench(
kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40
)
except triton.OutOfResources:
return [float("inf"), float("inf"), float("inf")]
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
# This reduces the amount of autotuning by rounding the keys to the nearest power of two
# In my testing this gives decent results, and greatly reduces the amount of tuning required
if self.nearest_power_of_two:
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {
config: self._bench(*args, config=config, **kwargs)
for config in pruned_configs
}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
:top_k
]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
def autotune(
configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
],
key=['x_size'] # the two above configs will be evaluated anytime
# the value of x_size changes
)
@triton.jit
def kernel(x_ptr, x_size, **META):
BLOCK_SIZE = META['BLOCK_SIZE']
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
:type key: list[str]
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early pruning (e.g., num_stages). It takes configs: List[Config] as its input and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
def decorator(fn):
return Autotuner(
fn,
fn.arg_names,
configs,
key,
reset_to_zero,
prune_configs_by,
nearest_power_of_two,
)
return decorator
def matmul248_kernel_config_pruner(configs, nargs):
"""
The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
"""
m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
used = set()
for config in configs:
block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
group_size_m = config.kwargs["GROUP_SIZE_M"]
if (
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
) in used:
continue
used.add(
(
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
)
)
yield triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
},
num_stages=config.num_stages,
num_warps=config.num_warps,
)
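The autotuner above also rounds cache keys to the nearest power of two (in log space) before looking up tuned configs, which keeps the number of distinct tuning runs small. A minimal sketch of that rounding, assuming keys are positive integers:

import math

def round_key(x: int) -> int:
    # Same rounding the autotuner applies to cache keys when
    # nearest_power_of_two is enabled.
    return 2 ** int(math.log2(x) + 0.5)

print(round_key(100))  # 128
print(round_key(20))   # 16
print(round_key(512))  # 512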

View File

@ -1,134 +0,0 @@
from text_generation_server.layers.gptq import GPTQWeight
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(
qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None
TEMP_STATE = None
TEMP_DQ = None
def set_device(device):
global DEVICE
DEVICE = device
def create_exllama_buffers(max_total_tokens: int):
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
assert DEVICE is not None, "call set_device first"
if not ACT_ORDER:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
temp_state = torch.zeros(
(max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
)
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
prepare_buffers(DEVICE, temp_state, temp_dq)
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class Ex4bitLinear(torch.nn.Module):
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, weight: GPTQWeight, bias):
super().__init__()
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert weight.bits == 4
self.device = weight.qweight.device
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and (
(self.g_idx == 0).all()
or torch.equal(
weight.g_idx.cpu(),
torch.tensor(
[i // weight.groupsize for i in range(weight.g_idx.shape[0])],
dtype=torch.int32,
),
)
):
self.empty_g_idx = True
self.g_idx = None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
)
self.height = weight.qweight.shape[0] * 8
self.width = weight.qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert weight.groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
if self.groupsize is None:
raise ValueError("Found group index but no groupsize. What do?")
self.act_order = True
else:
self.act_order = False
DEVICE = self.qweight.device
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
if self.act_order:
MAX_INNER = max(MAX_INNER, self.height, self.width)
ACT_ORDER = True
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out

View File

@ -1,267 +0,0 @@
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.layers.gptq import GPTQWeight
from text_generation_server.utils.log import log_master
try:
from exllamav2.ext import exllamav2_ext
make_q_matrix = exllamav2_ext.make_q_matrix
gemm_half_q_half = exllamav2_ext.gemm_half_q_half
except ImportError:
log_master(logger.warning, "exllamav2_kernels not installed.")
raise
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@dataclass
class _ExtraTensors:
"""Additional generated quantizer tensors."""
q_group_map: Optional[torch.Tensor] = None
q_invperm: Optional[torch.Tensor] = None
q_perm: Optional[torch.Tensor] = None
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
"""Matrix multiplication, returns x @ q4"""
output_shape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
gemm_half_q_half(x, q_handle, output, force_cuda)
return output.view(output_shape)
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
gr = q_groups.tolist()
group_map = []
num_groups = len(gr) // 2
for i in range(num_groups):
bits = gr[i * 2]
if i < num_groups - 1:
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
else:
qrows = num_qrows - gr[i * 2 + 1]
rows = qrows * 32 // bits
for j in range(rows):
group_map += [i]
group_map += [rows - j]
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
# Create Q matrix
def ext_make_q_matrix(
w: Exl2Weight | GPTQWeight,
extra: _ExtraTensors,
temp_dq,
key: Optional[str] = None,
):
"""
Create Q matrix
"""
# max_dq_size = 512*(1024**2)
# max_dq_rows = max_dq_size // out_features[0]
max_dq_rows = 0
# EXL2
if isinstance(w, Exl2Weight):
extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
extra.q_perm = torch.argsort(w.q_invperm).short()
return make_q_matrix(
w.q_weight,
extra.q_perm,
w.q_invperm,
w.q_scale,
w.q_scale_max,
w.q_groups,
extra.q_group_map,
none_tensor, # zeros
none_tensor, # scales
none_tensor, # g_idx
none_tensor, # bias
temp_dq,
max_dq_rows,
)
# GPTQ
elif isinstance(w, GPTQWeight):
if w.scales.dtype == torch.float:
w.scales = w.scales.half()
# GPTQ with g_idx (act_order)
if w.g_idx is not None and not (w.g_idx == 0).all().item():
extra.q_perm = torch.empty(
(w.qweight.shape[0] * 8,),
dtype=torch.short,
device=w.qweight.device,
)
extra.q_invperm = torch.empty_like(extra.q_perm)
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
return make_q_matrix(
w.qweight,
extra.q_perm,
extra.q_invperm,
none_tensor, # q_scale
none_tensor, # q_scale_max
none_tensor, # q_groups
none_tensor, # q_group_map
w.qzeros,
w.scales,
w.g_idx.cpu(),
none_tensor, # bias
temp_dq,
max_dq_rows,
)
# GPTQ without g_idx
else:
return make_q_matrix(
w.qweight,
none_tensor, # q_perm
none_tensor, # q_invperm
none_tensor, # q_scale
none_tensor, # q_scale_max
none_tensor, # q_groups
none_tensor, # q_group_map
w.qzeros,
w.scales,
none_tensor, # g_idx
none_tensor, # bias
temp_dq,
max_dq_rows,
)
else:
RuntimeError("Cannot create handle")
DEVICE = None
LAYERS = []
def set_device(device):
global DEVICE
DEVICE = device
def create_exllama_buffers(max_total_tokens: int):
global LAYERS, DEVICE
# No need to initialize scratch space if there are no layers
# that use ExLLamav2.
if len(LAYERS) == 0:
return
# Find the size of the scratch space.
scratch_bytes = max(
layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
for layer in LAYERS
)
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
for layer in LAYERS:
layer.post_init(temp_dq)
class QuantLinear(nn.Module):
QUANT_TYPE = "exllamav2"
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(
self,
weight: Exl2Weight | GPTQWeight,
bias: torch.Tensor,
):
super().__init__()
self.q_handle = None
self.q_tensors = weight
self.extra_tensors = _ExtraTensors()
if isinstance(weight, Exl2Weight):
self.infeatures = weight.q_invperm.shape[0]
self.outfeatures = weight.q_weight.shape[1]
elif isinstance(weight, GPTQWeight):
if weight.bits != 4:
raise ValueError(
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
)
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
self.outfeatures = weight.qweight.shape[1]
self.padding = -self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding
self.device = weight.device
self.bias = bias if bias is not None else None
global LAYERS
LAYERS.append(self)
def post_init(self, temp_dq):
device = self.q_tensors.device
assert device.type == "cuda"
assert device.index is not None
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
# and `Memory access fault by GPU node-2` will EAT you.
self.temp_dq = temp_dq
self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
def forward(self, x, force_cuda=False):
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
if self.bias is not None:
output.add_(self.bias)
return output
def temp_dq_size(self):
return self.infeatures * self.outfeatures * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len, max_batch_size):
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
class ExLlamaV2DeviceTensors:
device_idx: int
scratch_bytes: int
scratch_idx: int
scratch: torch.tensor = None
def __init__(self, device, scratch_bytes):
self.device = device
self.scratch_bytes = scratch_bytes
def prepare(self):
self.scratch = torch.empty(
(self.scratch_bytes // 2,), dtype=torch.half, device=self.device
)
def get_scratch_slice(self, size_bytes):
if self.scratch is None:
self.prepare()
size_bytes = ((size_bytes + 127) // 128) * 128
size_half = size_bytes // 2
scratch_slice = self.scratch.narrow(0, 0, size_half)
return scratch_slice

View File

@ -0,0 +1,186 @@
import math
import numpy as np
import torch
import torch.nn as nn
try:
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
def pack_tensor(input, bits=4):
normal = input.to(torch.int32)
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
self.wf = torch.tensor(
list(range(0, 32, self.bits)), dtype=torch.int32
).unsqueeze(0)
self._preprocessing()
def unpack_zeros_from_cuda_old_format(self):
zeros = torch.bitwise_right_shift(
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0),
).to(torch.int16 if self.bits == 8 else torch.int8)
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**self.bits) - 1).to(
self.scales.dtype
) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
return zeros
def unpack_weight_from_cuda_old_format(self):
weight = torch.bitwise_right_shift(
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1),
).to(torch.int16 if self.bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
weight = weight.reshape((weight.shape[0] * weight.shape[1], weight.shape[2]))
return weight
def _preprocessing(self):
orig_device = self.qweight.device
self.qweight = self.qweight.cpu()
weight = self.unpack_weight_from_cuda_old_format()
new_qweight = pack_tensor(weight)
self.qweight = new_qweight.to(orig_device)
# TODO: Support group indexing and remove the check
columns = self.qweight.shape[0]
g_idx_trivial = [i // self.groupsize for i in range(columns)]
g_idx_trivial = torch.tensor(
g_idx_trivial, dtype=torch.int32, device=self.g_idx.device
)
assert torch.equal(
self.g_idx, g_idx_trivial
), "Non-trivial tensor g_idx is not supported"
self.qzeros = self.qzeros.cpu()
zeros = self.unpack_zeros_from_cuda_old_format()
new_qzeros = pack_tensor(zeros)
self.qzeros = new_qzeros.to(orig_device)
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [4]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 4 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
weight = convert_from_uint4(self.qweight, self.scales, self.qzeros, x.dtype)
out = torch.matmul(x, weight)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out
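
As a sanity check of the packing layout: pack_tensor stores eight 4-bit values per int32, lowest nibble first, and convert_from_uint4 undoes this on the HPU side. The sketch below is CPU-only and purely illustrative; unpack_tensor is a hypothetical helper written for this note, not part of the file above.

import torch

def unpack_tensor(packed, bits=4):
    # Inverse of pack_tensor above: expand each int32 column into 32 // bits values.
    per_int = 32 // bits
    rows, cols = packed.shape
    out = torch.empty((rows, cols * per_int), dtype=torch.int32)
    for col in range(cols):
        for k in range(per_int):
            out[:, col * per_int + k] = (packed[:, col] >> (bits * k)) & ((1 << bits) - 1)
    return out

values = torch.randint(0, 16, (2, 32), dtype=torch.int32)  # 32 columns of 4-bit values
packed = pack_tensor(values)                               # shape (2, 4): 8 values per int32
assert torch.equal(unpack_tensor(packed), values)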

View File

@ -1,359 +0,0 @@
import math
import numpy as np
import torch
import torch.nn as nn
from torch.cuda.amp import custom_fwd
import triton
import triton.language as tl
from . import custom_autotune
# code based https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=4,
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
)
@triton.jit
def matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk
+ offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(
scales_ptrs + g_idx[:, None] * stride_scales
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(
zeros_ptrs + g_idx[:, None] * stride_zeros
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = (zeros + 1) & maxq # add one and mask to avoid overflow
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty(
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
)
def grid(META):
return (
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
return output
class QuantLinear(nn.Module):
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
super().__init__()
self.register_buffer("qweight", qweight)
self.register_buffer("qzeros", qzeros)
self.register_buffer("scales", scales)
self.register_buffer("g_idx", g_idx)
if bias is not None:
self.register_buffer("bias", bias)
else:
self.bias = None
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
self.bits = bits
self.maxq = 2**self.bits - 1
self.groupsize = groupsize
self.outfeatures = qweight.shape[1]
self.infeatures = qweight.shape[0] * 32 // bits
@classmethod
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
qzeros = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
dtype=torch.int32,
)
scales = torch.zeros(
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
)
g_idx = torch.tensor(
[i // groupsize for i in range(infeatures)], dtype=torch.int32
)
if bias:
bias = torch.zeros((outfeatures), dtype=torch.float16)
else:
bias = None
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
def pack(self, linear, scales, zeros, g_idx=None):
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
i = 0
row = 0
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq,
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
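
For reference, the dequantization that this removed Triton kernel performed on the fly (and that the HPU path now delegates to convert_from_uint4) is w[k, n] = (q[k, n] - zero[g[k], n]) * scale[g[k], n]. A minimal eager sketch with toy shapes, assuming the 4-bit values are already unpacked:

import torch

def dequantize_reference(q, scales, zeros, g_idx):
    # q: (K, N) unpacked int weights, scales/zeros: (G, N), g_idx: (K,) group id per input row.
    return (q - zeros[g_idx]) * scales[g_idx]

K, N, groupsize = 8, 4, 4
q = torch.randint(0, 16, (K, N)).float()
scales = torch.rand(K // groupsize, N)
zeros = torch.randint(0, 16, (K // groupsize, N)).float()
g_idx = torch.arange(K) // groupsize
w = dequantize_reference(q, scales, zeros, g_idx)  # dense (K, N) weights, usable with torch.matmul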

View File

@ -12,7 +12,7 @@ from huggingface_hub import HfApi
from accelerate import init_empty_weights
from text_generation_server.utils import initialize_torch_distributed, Weights
from text_generation_server.utils.hub import weight_files
-from text_generation_server.layers.gptq.quant_linear import QuantLinear
+from text_generation_server.layers.gptq import QuantLinear
from loguru import logger
from typing import Optional
from text_generation_server.layers.gptq.utils import torch_snr_error
@ -956,15 +956,24 @@ def quantize(
    pack(model, quantizers, bits, groupsize)
    from safetensors.torch import save_file
-    from transformers.modeling_utils import shard_checkpoint
+    from huggingface_hub import split_torch_state_dict_into_shards
    state_dict = model.state_dict()
    state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
    max_shard_size = "10GB"
-    shards, index = shard_checkpoint(
-        state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
+    state_dict_split = split_torch_state_dict_into_shards(
+        state_dict,
+        filename_pattern="model.safetensors",
+        max_shard_size=max_shard_size,
    )
+    index = None
+    if state_dict_split.is_sharded:
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+    shards = state_dict_split.filename_to_tensors
    os.makedirs(output_dir, exist_ok=True)
    for shard_file, shard in shards.items():
        save_file(
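
A sketch of the new sharding flow with huggingface_hub; save_sharded and the index filename below are illustrative names, not taken from the file above:

import json
import os

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

def save_sharded(state_dict, output_dir, max_shard_size="10GB"):
    split = split_torch_state_dict_into_shards(
        state_dict,
        filename_pattern="model.safetensors",
        max_shard_size=max_shard_size,
    )
    os.makedirs(output_dir, exist_ok=True)
    # filename_to_tensors maps each shard file to the tensor names it should contain.
    for shard_file, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name] for name in tensor_names}
        save_file(shard, os.path.join(output_dir, shard_file))
    if split.is_sharded:
        index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
        with open(os.path.join(output_dir, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f, indent=2)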

View File

@ -1,9 +1,6 @@
import torch
from torch import nn
from accelerate import init_empty_weights
-from text_generation_server.utils.import_utils import (
-    SYSTEM,
-)

# Monkey patching
@ -33,69 +30,14 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
torch.nn.LayerNorm.load = load_layer_norm
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
-if SYSTEM == "cuda":
-    import dropout_layer_norm
-
-    class FastLayerNorm(nn.LayerNorm):
-        def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
-                if residual is not None:
-                    hidden_states += residual
-                residual = hidden_states
-                return super(FastLayerNorm, self).forward(hidden_states), residual
-            else:
-                (
-                    normed_hidden_states,
-                    residual,
-                    *rest,
-                ) = dropout_layer_norm.dropout_add_ln_fwd(
-                    hidden_states,
-                    residual,
-                    self.weight,
-                    self.bias,
-                    None,
-                    None,
-                    None,
-                    None,
-                    0.0,
-                    self.eps,
-                    1.0,
-                    0,
-                    None,
-                    False,
-                    False,
-                )
-                if residual is None:
-                    residual = hidden_states
-                return normed_hidden_states, residual
-elif SYSTEM == "rocm":
-    from vllm._C import ops
-
-    class FastLayerNorm(nn.LayerNorm):
-        def forward(self, hidden_states, residual=None):
-            if residual is not None:
-                hidden_states += residual
-                residual = hidden_states
-            return super().forward(hidden_states), residual
-elif SYSTEM == "ipex":
-    import intel_extension_for_pytorch as ipex
-
-    class FastLayerNorm(nn.LayerNorm):
-        def forward(self, hidden_states, residual=None):
-            out = ipex.llm.functional.add_layer_norm(
-                residual,
-                hidden_states,
-                self.weight,
-                self.bias,
-                self.eps,
-                residual is not None,
-            )
-            return out, residual if residual is not None else hidden_states
+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if residual is not None:
+            hidden_states += residual
+        residual = hidden_states
+
+        return super().forward(hidden_states), residual

class FastRMSNorm(nn.Module):
@ -111,74 +53,15 @@ class FastRMSNorm(nn.Module):
        return cls(weight, eps)

    def forward(self, hidden_states, residual=None):
-        if SYSTEM == "ipex":
-            out = ipex.llm.functional.add_rms_norm(
-                residual,
-                hidden_states,
-                self.weight,
-                None,
-                self.variance_epsilon,
-                residual is not None,
-            )
-            return out, residual if residual is not None else hidden_states
-        elif hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            hidden_states = hidden_states.to(torch.float32)
-            variance = hidden_states.pow(2).mean(-1, keepdim=True)
-            hidden_states = hidden_states * torch.rsqrt(
-                variance + self.variance_epsilon
-            )
-
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-
-            return self.weight * hidden_states, residual
-        elif SYSTEM == "cuda":
-            # faster post attention rms norm
-            (
-                normed_hidden_states,
-                res,
-                *rest,
-            ) = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                None,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.variance_epsilon,
-                1.0,
-                0,
-                None,
-                False,
-                True,  # Activate RMSNorm
-            )
-            if res is None:
-                res = hidden_states
-            return normed_hidden_states, res
-        elif SYSTEM == "rocm":
-            # We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
-            if residual is not None:
-                hidden_states += residual
-                residual = hidden_states
-
-            out = torch.empty_like(hidden_states)
-            ops.rms_norm(
-                out,
-                hidden_states,
-                self.weight.data,
-                self.variance_epsilon,
-            )
-            return out, residual
-        else:
-            raise ValueError(
-                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
-            )
+        from vllm_hpu_extension.kernels import rms_norm
+
+        orig_shape = hidden_states.shape
+        if residual is not None:
+            residual += hidden_states.view(residual.shape)
+        else:
+            residual = hidden_states
+        # Note: HPUFusedRMSNorm requires 3D tensors as inputs
+        if len(orig_shape) == 2:
+            residual = residual.unsqueeze(0)
+        x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
+        return x.view(orig_shape), residual.view(orig_shape)
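
For readers without Gaudi hardware, this is roughly the eager-mode computation that the fused HPU rms_norm kernel replaces (a reference sketch only, not the kernel itself):

import torch

def rms_norm_reference(hidden_states, weight, eps):
    # Residual add is done by the caller; this is only the normalization step.
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
    return weight * normed.to(weight.dtype)

x = torch.randn(1, 4, 128, dtype=torch.bfloat16)
residual = torch.randn_like(x)
weight = torch.ones(128, dtype=torch.bfloat16)
residual = residual + x                      # the residual add that forward() folds in
out = rms_norm_reference(residual, weight, 1e-6)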

View File

@ -1,21 +1,5 @@
import torch
-from text_generation_server.utils.import_utils import SYSTEM
from torch.nn import functional as F
-import os
-
-if SYSTEM == "rocm":
-    ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
-        "true",
-        "1",
-    )
-
-    if ROCM_USE_SKINNY_GEMM:
-        try:
-            from vllm import _custom_C
-        except Exception as e:
-            raise ImportError(
-                f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
-            )

class FastLinear(torch.nn.Module):
@ -44,83 +28,11 @@ class FastLinear(torch.nn.Module):
        return F.linear(input, self.weight, self.bias)

-class FastLinearROCm(torch.nn.Module):
-    def __init__(
-        self,
-        weight,
-        bias,
-    ) -> None:
-        super().__init__()
-        self.weight = torch.nn.Parameter(weight)
-        if bias is not None:
-            self.bias = torch.nn.Parameter(bias)
-        else:
-            self.bias = None
-
-        self.cu_count = torch.cuda.get_device_properties(
-            device="cuda"
-        ).multi_processor_count
-        self.use_skinny_gemm = (
-            ROCM_USE_SKINNY_GEMM
-            and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
-        )
-
-    @classmethod
-    def load(cls, config, prefix: str, weights, bias: bool):
-        weight = weights.get_tensor(f"{prefix}.weight")
-        if bias:
-            bias = weights.get_tensor(f"{prefix}.bias")
-        else:
-            bias = None
-        return cls(weight, bias)
-
-    def forward(self, inp: torch.Tensor) -> torch.Tensor:
-        weight = self.weight
-        bias = self.bias
-
-        if (
-            self.use_skinny_gemm
-            and inp.dtype == torch.float16
-            and inp.shape[-1] % 8 == 0
-        ):
-            batched = False
-            inp_shape = inp.shape
-
-            if inp.dim() == 3:
-                inp = inp.view(-1, inp_shape[-1])
-                batched = True
-
-            m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
-            if m > 8 and n <= 4:
-                out = torch.empty(
-                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
-                )
-                _custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
-            elif m % 4 == 0 and n == 1 and k <= 8192:
-                out = torch.empty(
-                    inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
-                )
-                _custom_C.LLMM1(weight, inp, out, 4)
-            else:
-                out = F.linear(inp, weight)
-
-            if batched:
-                out.view(*inp_shape[:-1], out.shape[-1])
-
-            if bias is not None:
-                out = out + bias
-            return out
-        return F.linear(inp, self.weight, self.bias)

def get_linear(weight, bias):
    # Weights that are loaded through methods that are not
    # quantization-aware are still bare tensors. We may want
    # to change this in the future.
    if isinstance(weight, torch.Tensor):
-        if SYSTEM == "rocm":
-            return FastLinearROCm(weight, bias)
-        else:
-            return FastLinear(weight, bias)
+        return FastLinear(weight, bias)

    return weight.get_linear(bias)

View File

@ -1,15 +0,0 @@
from text_generation_server.layers.marlin.fp8 import GPTQMarlinFP8Linear
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeightsLoader,
can_use_gptq_marlin,
repack_gptq_for_marlin,
)
from text_generation_server.layers.marlin.marlin import MarlinWeightsLoader
__all__ = [
"GPTQMarlinFP8Linear",
"GPTQMarlinWeightsLoader",
"MarlinWeightsLoader",
"can_use_gptq_marlin",
"repack_gptq_for_marlin",
]

View File

@ -1,140 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.layers.marlin.gptq import _check_valid_shape
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
permute_scales,
)
from text_generation_server.utils.log import log_once
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
MARLIN_TILE_SIZE = 16
class GPTQMarlinFP8Linear(nn.Module):
"""
FP8 GPTQ-Marlin linear layer.
"""
def __init__(
self,
qweight: torch.Tensor,
scales: torch.Tensor,
bias: Optional[torch.Tensor],
) -> None:
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
scales = scales.unsqueeze(0)
if scales.shape[1] == 1:
out_features, in_features = qweight.shape
scales = scales.repeat(1, out_features)
qweight, scales = repack_fp8_for_marlin(qweight, scales)
in_features = qweight.shape[0] * MARLIN_TILE_SIZE
out_features = scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.qweight = qweight
self.scales = scales
self.bias = bias if bias is not None else None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=qweight.device
)
@classmethod
def from_unquant(cls, weight, bias, dtype):
qweight, scales = fp8_quantize(weight)
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
@classmethod
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.fp8_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.workspace,
8,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
)
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)
for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)
return packed
def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
"""
Repack FP8 tensor for GPTQ-Marlin.
"""
out_features, in_features = weight.shape
# Torch linear layer weights have shape [out_features, in_features];
# GPTQ-quantized weights use [in_features/pack_factor, out_features],
# so transpose before packing.
qweight = pack_fp8_as_int32(weight.t())
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, 8
)
scales = permute_scales(scales)
return repacked, scales
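
For completeness, a small shape check of the byte packing that the removed FP8 Marlin path relied on; this assumes a torch build with float8_e4m3fn support and is illustrative only:

import torch

w = torch.randn(8, 16).to(torch.float8_e4m3fn)
packed = pack_fp8_as_int32(w)            # four FP8 bytes per int32, packed along the leading dim
assert packed.dtype == torch.int32
assert packed.shape == (w.shape[0] // 4, w.shape[1])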

View File

@ -1,464 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy
import torch
import torch.nn as nn
from loguru import logger
from text_generation_server.layers.marlin.util import (
_check_marlin_kernels,
marlin_zero_points,
permute_scales,
unpack_cols,
)
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
GPTQ_MARLIN_BITS = [4, 8]
GPTQ_MARLIN_GROUP_SIZES = [-1, 32, 64, 128]
MARLIN_TILE_SIZE = 16
def can_use_gptq_marlin(
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
) -> bool:
return (
SYSTEM == "cuda"
and marlin_kernels is not None
and has_sm_8_0
and quantize in {"awq", "gptq"}
and quant_method in {"awq", "gptq"}
and bits in GPTQ_MARLIN_BITS
and groupsize in GPTQ_MARLIN_GROUP_SIZES
# We only support asymmetric quantization for AWQ.
and (sym or quant_method == "awq")
)
class GPTQMarlinWeightsLoader(WeightsLoader):
"""
Loader for using GPTQ- and AWQ-quantized weights with Marlin kernels.
"""
def __init__(
self,
*,
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
quantize: str,
sym: bool,
):
self.bits = bits
self.desc_act = desc_act
self.groupsize = groupsize
self.quant_method = quant_method
self.quantize = quantize
self.sym = sym
def get_weights(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_tensor(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
scales = weights.get_tensor(f"{prefix}.scales")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
try:
qweight = weights.get_packed_sharded(
f"{prefix}.qweight", dim=1, block_sizes=block_sizes
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
)
scales = weights.get_packed_sharded(
f"{prefix}.scales", dim=1, block_sizes=block_sizes
)
scales = scales.to(dtype=weights.dtype)
if not self.sym:
qzeros = weights.get_packed_sharded(
f"{prefix}.qzeros", dim=1, block_sizes=block_sizes
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_tensor(f"{prefix}.g_idx")
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
try:
qweight = torch.cat(
[weights.get_sharded(f"{p}.qweight", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized"
)
scales = torch.cat(
[weights.get_sharded(f"{p}.scales", dim=1) for p in prefixes], dim=1
)
if not self.sym:
qzeros = torch.cat(
[weights.get_sharded(f"{p}.qzeros", dim=1) for p in prefixes], dim=1
)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=False,
)
def get_weights_row(self, weights: Weights, prefix: str):
log_once(logger.info, "Using GPTQ-Marlin kernels")
try:
qweight = weights.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{self.quantize}` weight for GPTQ -> Marlin repacking, make sure the model is already quantized"
)
if not self.sym:
if self.desc_act or self.groupsize == -1:
qzeros = weights.get_tensor(f"{prefix}.qzeros")
else:
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
else:
qzeros = None
if self.quant_method == "awq":
g_idx = None
else:
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
if self.desc_act or self.groupsize == -1:
scales = weights.get_tensor(f"{prefix}.scales")
else:
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
sharded_in_features = weights.process_group.size() > 1
return repack_gptq_for_marlin(
qweight=qweight,
scales=scales,
qzeros=qzeros,
g_idx=g_idx,
bits=self.bits,
desc_act=self.desc_act,
groupsize=self.groupsize,
quant_method=self.quant_method,
sym=self.sym,
sharded_infeatures=sharded_in_features,
)
def _get_gptq_params(self, weights: Weights):
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
self.bits = weights.get_tensor("gptq_bits").item()
self.groupsize = weights.get_tensor("gptq_groupsize").item()
self.desc_act = False
# `server quantize` used asymmetric quantization unconditionally
# before the `gptq_sym` setting tensor was added.
self.sym = (
weights.get_tensor("gptq_sym").item()
if weights._has_tensor("gptq_sym")
else False
)
self.quant_method = "gptq"
@dataclass
class GPTQMarlinWeight(Weight):
"""
Repacked GPTQ Marlin weights.
"""
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
bits: int
is_full_k: bool
def __post_init__(self):
assert self.qweight.dtype == torch.int32
assert self.scales.dtype == torch.float16
assert self.g_idx.dtype == torch.int32
assert self.perm.dtype == torch.int32
def get_linear(self, bias: torch.Tensor):
return GPTQMarlinLinear(
weight=self,
bias=bias,
)
def repack_gptq_for_marlin(
*,
qweight: torch.Tensor,
qzeros: Optional[torch.Tensor],
scales: torch.Tensor,
g_idx: Optional[torch.Tensor],
bits: int,
desc_act: bool,
groupsize: int,
quant_method: str,
sym: bool,
sharded_infeatures: bool,
) -> GPTQMarlinWeight:
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
_check_marlin_kernels()
assert marlin_kernels is not None
if bits not in GPTQ_MARLIN_BITS:
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
raise RuntimeError(
f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
)
if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
raise RuntimeError(
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
)
if not (sym or quant_method == "awq"):
raise RuntimeError(
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
)
log_once(logger.info, f"Converting {quant_method} model to Marlin packing format.")
weights_per_int = 32 // bits
in_features = qweight.shape[0]
out_features = qweight.shape[1]
# AWQ uses column packing, GPTQ uses row packing
if quant_method == "awq":
out_features *= weights_per_int
else:
in_features *= weights_per_int
if in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
)
if g_idx is not None and desc_act and groupsize != -1:
perm = torch.argsort(g_idx).to(torch.int)
g_idx = g_idx[perm]
else:
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
if quant_method == "awq":
repacked = marlin_kernels.awq_marlin_repack(
qweight, in_features, out_features, bits
)
if qzeros is not None:
qzeros = awq_to_marlin_zero_points(
qzeros,
in_features // groupsize,
out_features,
bits,
)
else:
repacked = marlin_kernels.gptq_marlin_repack(
qweight, perm, in_features, out_features, bits
)
if qzeros is None:
qzeros = torch.empty(0, dtype=torch.int, device=qweight.device)
scales = permute_scales(scales)
is_full_k = not (desc_act and groupsize != -1 and sharded_infeatures)
return GPTQMarlinWeight(
qweight=repacked,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
bits=bits,
is_full_k=is_full_k,
)
class GPTQMarlinLinear(nn.Module):
"""
Linear layer for GPTQ weights that were converted for the GPTQ-Marlin
kernels.
"""
def __init__(
self,
*,
weight: GPTQMarlinWeight,
bias: Optional[torch.Tensor],
):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
out_features = weight.scales.shape[1]
_check_valid_shape(in_features=in_features, out_features=out_features)
self.bits = weight.bits
self.is_full_k = weight.is_full_k
self.qweight = weight.qweight
self.qzeros = weight.qzeros
self.scales = weight.scales
self.g_idx = weight.g_idx
self.perm = weight.perm
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.qweight.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
A_flat = A.view(-1, A.shape[-1])
C = marlin_kernels.gptq_marlin_gemm(
A_flat,
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.perm,
self.workspace,
self.bits,
A_flat.shape[0],
self.scales.shape[1],
A_flat.shape[1],
self.is_full_k,
self.qzeros.numel() > 0,
True,
)
C = C.reshape(A.shape[:-1] + (self.scales.shape[1],))
if self.bias is not None:
C += self.bias
return C
def awq_to_marlin_zero_points(
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
# AWQ zero-points are quantized and packed on the column dim.
# In addition, the values are permuted based on dequantizer.
# Here we undo both of these, and then apply marlin permutation
# and pack it back.
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
# Undo interleaving (use argsort(..) to get inverse perm)
if num_bits == 4:
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
elif num_bits == 8:
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
q_zp = q_zp.reshape((-1, size_n)).contiguous()
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
return marlin_zp
def _check_valid_shape(in_features: int, out_features: int):
if (in_features % 128 != 0 or out_features % 64 != 0) and (
in_features % 64 != 0 or out_features % 128 != 0
):
raise ValueError(
f"The GPTQ Marlin kernel does not have a valid thread configuration for weight matrix with shape ({out_features}, {in_features})."
" The shape elements must be divisible by (128, 64) or (64, 128)."
)

View File

@ -1,346 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
import torch.nn as nn
from text_generation_server.layers.marlin.util import _check_marlin_kernels
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
class MarlinWeightsLoader(WeightsLoader):
"""Loader for Marlin-quantized weights."""
def __init__(self, *, bits: int, is_marlin_24: bool):
self.bits = bits
self.is_marlin_24 = is_marlin_24
def get_weights(self, weights: "Weights", prefix: str):
"""
Get weights at the given prefix and apply without tensor parallelism.
"""
is_marlin_24 = getattr(self, "gptq_checkpoint_format", None) == "marlin_24"
if is_marlin_24:
try:
B = weights.get_tensor(f"{prefix}.B_24")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_tensor(f"{prefix}.B_meta")
s = weights.get_tensor(f"{prefix}.s")
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_tensor(f"{prefix}.B")
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
s = weights.get_tensor(f"{prefix}.s")
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_col_packed(
self,
weights: Weights,
prefix: str,
block_sizes: Union[int, List[int]],
):
if self.is_marlin_24:
B = weights.get_packed_sharded(
f"{prefix}.B_24", dim=1, block_sizes=block_sizes
)
B_meta = weights.get_packed_sharded(
f"{prefix}.B_meta", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
B = weights.get_packed_sharded(
f"{prefix}.B", dim=1, block_sizes=block_sizes
)
s = weights.get_packed_sharded(
f"{prefix}.s", dim=1, block_sizes=block_sizes
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
if self.is_marlin_24:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B_24", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
B_meta = torch.cat(
[weights.get_sharded(f"{p}.B_meta", dim=1) for p in prefixes], dim=1
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = torch.cat(
[weights.get_sharded(f"{p}.B", dim=1) for p in prefixes], dim=1
)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized"
)
s = torch.cat(
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
)
weight = MarlinWeight(B=B, s=s)
return weight
def get_weights_row(self, weights: Weights, prefix: str):
if self.is_marlin_24:
try:
B = weights.get_sharded(f"{prefix}.B_24", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` 2:4 sparsity weight, make sure the model is already quantized."
)
B_meta = weights.get_sharded(f"{prefix}.B_meta", dim=0)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
else:
try:
B = weights.get_sharded(f"{prefix}.B", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `marlin` weight, make sure the model is already quantized."
)
num_groups = weights._get_slice(f"{prefix}.s").get_shape()[0]
if num_groups == 1:
# The number of groups is 1 when groupsize == -1. share
# scales between all shards in this case.
s = weights.get_tensor(f"{prefix}.s")
else:
s = weights.get_sharded(f"{prefix}.s", dim=0)
weight = MarlinWeight(B=B, s=s)
return weight
@dataclass
class MarlinWeight(Weight):
"""
Marlin weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
s (torch.Tensor): bfloat16/float16 scales.
"""
B: torch.Tensor
s: torch.Tensor
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.s.dtype in [torch.float16, torch.bfloat16]
def get_linear(self, bias: torch.Tensor):
return MarlinLinear(weight=self, bias=bias)
class MarlinLinear(nn.Module):
def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
out_features = weight.s.shape[1]
assert (
in_features % 128 == 0
), f"Number of input features ({in_features}) not divisable by 128"
assert (
out_features % 256 == 0
), f"Number of output features ({out_features}) not divisable by 256"
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
assert groupsize in {
-1,
128,
}, f"Group size must be -1 or 128, was {groupsize}"
self.B = weight.B
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
out_features // 64 * 16, dtype=torch.int, device=weight.B.device
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.marlin_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.s,
self.workspace,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
MARLIN_TILE_SIZE = 16
@dataclass
class GPTQMarlin24Weight:
"""
GPTQ-Marlin 2:4 weights.
Attributes:
B (torch.Tensor): int4-quantized weights packed into int32.
B_meta (torch.Tensor): metadata for 2:4 sparsity.
s (torch.Tensor): float16 scales.
bits: quantized weight size.
"""
B: torch.Tensor
B_meta: torch.Tensor
s: torch.Tensor
bits: int
def __post_init__(self):
assert self.B.dtype == torch.int32
assert self.B_meta.dtype == torch.int16
assert self.s.dtype == torch.float16
def get_linear(self, bias: torch.Tensor):
return GPTQMarlin24Linear(
weight=self,
bias=bias,
)
class GPTQMarlin24Linear(nn.Module):
def __init__(self, *, weight: GPTQMarlin24Weight, bias: Optional[torch.Tensor]):
super().__init__()
_check_marlin_kernels()
assert marlin_kernels is not None
if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
supported_bits = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
)
raise RuntimeError(
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
)
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
out_features = weight.s.shape[1]
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
supported_sizes = ", ".join(
str(b) for b in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
)
raise RuntimeError(
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
)
self.bits = weight.bits
weights_per_int32 = 32 // self.bits
assert (
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_N} threads"
assert (
out_features % weights_per_int32 == 0
), f"Number of output features ({out_features}) not divisable by weights per int32 ({weights_per_int32})"
assert (
in_features % GPTQ_MARLIN_24_MIN_THREAD_K == 0
), f"Number of output features ({out_features}) not divisable by {GPTQ_MARLIN_24_MIN_THREAD_K} threads"
if groupsize != -1 and in_features % groupsize != 0:
raise ValueError(
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
)
self.B = weight.B
self.B_meta = weight.B_meta
self.s = weight.s
if bias is not None:
self.bias = bias
else:
self.bias = None
self.workspace = torch.zeros(
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
dtype=torch.int,
device=weight.B.device,
)
def forward(self, A: torch.Tensor) -> torch.Tensor:
assert marlin_kernels is not None
C = marlin_kernels.gptq_marlin_24_gemm(
A.view(-1, A.shape[-1]),
self.B,
self.B_meta,
self.s,
self.workspace,
self.bits,
A.shape[0],
self.s.shape[1],
A.shape[1],
)
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
if self.bias is not None:
C += self.bias
return C

View File

@ -1,141 +0,0 @@
import functools
from typing import List, Tuple
import numpy
import torch
from text_generation_server.utils.import_utils import SYSTEM
try:
import marlin_kernels
except ImportError:
marlin_kernels = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
def _check_marlin_kernels():
if not (SYSTEM == "cuda" and has_sm_8_0):
raise NotImplementedError(
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
)
if marlin_kernels is None:
raise NotImplementedError(
"marlin is not installed, install it with: pip install server/marlin"
)
# https://github.com/IST-DASLab/marlin/blob/2f6d7c10e124b3c5fa29ff8d77d568bd7af3274c/marlin/__init__.py#L40C1-L68C54
@functools.cache
def get_perms() -> Tuple[List[int], List[int]]:
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def permute_scales(scales: torch.Tensor):
scale_perm, scale_perm_single = get_perms()
out_features = scales.shape[1]
if scales.shape[0] == 1:
scales = scales.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
else:
scales = scales.reshape((-1, len(scale_perm)))[:, scale_perm]
return scales.reshape((-1, out_features)).contiguous()
# Functions below are from vLLM
def get_pack_factor(bits: int) -> int:
if 32 % bits != 0:
raise ValueError(f"Cannot {bits} bit values into uint32")
return 32 // bits
def pack_cols(
q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
assert q_w.shape == (size_k, size_n)
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
q_res |= q_w[:, i::pack_factor] << num_bits * i
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def unpack_cols(
packed_q_w: torch.Tensor,
num_bits: int,
size_k: int,
size_n: int,
):
pack_factor = get_pack_factor(num_bits)
assert size_n % pack_factor == 0
assert packed_q_w.shape == (
size_k,
size_n // pack_factor,
), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
packed_q_w.shape, size_k, size_n, pack_factor
)
orig_device = packed_q_w.device
packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
vals = packed_q_w_cpu & mask
packed_q_w_cpu >>= num_bits
q_res[:, i::pack_factor] = vals
q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
q_res = q_res.contiguous()
return q_res
def marlin_zero_points(
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
scale_perm, _ = get_perms()
# Permute zero-points in a similar way to scales, but do not use the
# "single" permutation, since zero-points are applied on every MMA
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
# Interleave column dim (for the dequantize code) and pack it to int32
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
zp = zp.reshape((-1, size_n)).contiguous()
zp = pack_cols(zp, num_bits, size_k, size_n)
return zp
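
A round-trip sanity check for the column packing helpers above, which are removed together with the rest of the Marlin support (purely illustrative shapes):

import torch

size_k, size_n, num_bits = 16, 64, 4
w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
packed = pack_cols(w, num_bits, size_k, size_n)              # (size_k, size_n // 8) int32
restored = unpack_cols(packed, num_bits, size_k, size_n)
assert torch.equal(restored, w)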

View File

@ -10,13 +10,8 @@ from text_generation_server.layers import (
    TensorParallelRowLinear,
)
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
-from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
-from text_generation_server.layers.moe.gptq_marlin import (
-    GPTQMarlinSparseMoELayer,
-    can_use_marlin_moe_gemm,
-)
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import (
@ -24,12 +19,7 @@ from text_generation_server.utils.weights import (
    UnquantizedWeight,
)
-if SYSTEM == "rocm":
-    from .fused_moe_rocm import grouped_topk
-    from vllm.model_executor.layers.fused_moe import fused_topk
-elif SYSTEM != "ipex":
-    from moe_kernels.fused_moe import fused_topk, grouped_topk
+from .fused_moe import fused_topk, grouped_topk

# NOTE: we are using a protocol here, because multiple inheritance is not nice.
# We need `Module`, and `Module` -> some abstract class -> some concrete
@ -52,6 +42,8 @@ class MoELayer(Protocol):
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
    ): ...

    def forward(
@ -81,9 +73,14 @@ class DenseMoELayer(nn.Module):
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
+        scoring_func: Optional[str] = None,
+        e_score_correction_bias: Optional[float] = None,
    ):
        super().__init__()
+        assert scoring_func is None, "scoring func is not handled"
+        assert e_score_correction_bias is None, "scoring correction bias is not handled"
        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
@ -199,22 +196,27 @@ class SparseMoELayer(nn.Module):
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
+        scoring_func: Optional[str] = "softmax",
+        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()
        if (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
-            cls = UnquantizedSparseMoELayer
-        elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
-            cls = GPTQMarlinSparseMoELayer
+            if (
+                isinstance(weights.loader, HybridFP8UnquantLoader)
+                and weights.loader.to_fp8
+            ):
+                cls = FP8SparseMoELayer
+            else:
+                cls = UnquantizedSparseMoELayer
        else:
            raise ValueError(
-                f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
+                f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
            )

        log_once(
@ -230,6 +232,8 @@ class SparseMoELayer(nn.Module):
            topk=topk,
            topk_group=topk_group,
            weights=weights,
+            scoring_func=scoring_func,
+            e_score_correction_bias=e_score_correction_bias,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            down_proj_name=down_proj_name,
@ -241,17 +245,6 @@ class SparseMoELayer(nn.Module):
    @staticmethod
    def is_supported(weights: Weights) -> bool:
        return (
-            (
-                isinstance(weights.loader, DefaultWeightsLoader)
-                and isinstance(weights.loader.weight_class, UnquantizedWeight)
-            )
-            or isinstance(weights.loader, HybridFP8UnquantLoader)
-            or (
-                isinstance(weights.loader, GPTQMarlinWeightsLoader)
-                and can_use_marlin_moe_gemm(
-                    quant_method=weights.loader.quant_method,
-                    quantize=weights.loader.quantize,
-                    sym=weights.loader.sym,
-                )
-            )
-        )
+            isinstance(weights.loader, DefaultWeightsLoader)
+            and isinstance(weights.loader.weight_class, UnquantizedWeight)
+        ) or isinstance(weights.loader, HybridFP8UnquantLoader)

View File

@ -0,0 +1,173 @@
from typing import Optional
import torch
import torch.nn as nn
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
Fp8Weight,
fp8_quantize,
quant_dtype,
normalize_e4m3fn_to_native_float8,
)
try:
from .unquantized import fused_moe
except Exception:
fused_moe = None
class FP8SparseMoELayer(nn.Module):
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
(
self.gate_up_proj,
self.gate_up_proj_weight_scale,
self.gate_up_proj_input_scale,
) = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
gate_proj_name=gate_proj_name,
up_proj_name=up_proj_name,
weights=weights,
)
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
_load_expert_weights_row(
prefix=prefix,
n_experts=n_experts,
name=down_proj_name,
weights=weights,
)
)
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_moe(
x,
w1=self.gate_up_proj,
w2=self.down_proj,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
inplace=True,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
use_fp8_w8a8=True,
w1_scale=self.gate_up_proj_weight_scale,
w2_scale=self.down_proj_weight_scale,
a1_scale=self.gate_up_proj_input_scale,
a2_scale=self.down_proj_input_scale,
)
def _load_expert_weights(
get_weight_fn,
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
all_weight = None
all_weight_scales = None
max_input_scale = None
for i in range(n_experts):
weight = get_weight_fn(prefix, i, name, weights)
assert isinstance(weight, Fp8Weight)
if all_weight is None:
all_weight = torch.empty(
(n_experts,) + weight.weight.shape,
dtype=quant_dtype,
device=weight.weight.device,
)
if all_weight_scales is None:
all_weight_scales = torch.empty(
(n_experts,) + weight.weight_scale.shape,
dtype=torch.float32,
device=weight.weight.device,
)
if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
all_weight[i], all_weight_scales[i], current_input_scale = (
normalize_e4m3fn_to_native_float8(
weight.weight, weight.weight_scale, weight.input_scale
)
)
if current_input_scale is not None:
if max_input_scale is None or current_input_scale > max_input_scale:
max_input_scale = current_input_scale
else:
all_weight[i], all_weight_scales[i] = fp8_quantize(
weight.weight, scalar=True
)
assert all_weight is not None
return all_weight, all_weight_scales, max_input_scale
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
gate_proj_name: str,
up_proj_name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_multi_weights_col(
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
)
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
)
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> torch.Tensor:
def get_weight_fn(prefix, i, name, weights):
return weights.get_weights_row(f"{prefix}.{i}.{name}")
return _load_expert_weights(
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
)
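
A shape-only sketch of the per-expert stacking that _load_expert_weights performs; dtypes are simplified to regular floats here, while the real code stores FP8 payloads plus per-expert scales:

import torch

n_experts, hidden_size, intermediate_size = 4, 64, 128
# gate_proj and up_proj are fused column-wise, hence the 2 * intermediate_size rows.
gate_up = torch.stack(
    [torch.randn(2 * intermediate_size, hidden_size) for _ in range(n_experts)]
)
down = torch.stack(
    [torch.randn(hidden_size, intermediate_size) for _ in range(n_experts)]
)
assert gate_up.shape == (n_experts, 2 * intermediate_size, hidden_size)
assert down.shape == (n_experts, hidden_size, intermediate_size)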

View File

@ -16,10 +16,8 @@
from typing import Tuple

import torch
-import torch.distributed
-# TODO: Remove the functions once moe_kernel are built for ROCM

def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
@ -50,3 +48,18 @@ def grouped_topk(
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    topk_weights = torch.nn.functional.softmax(
+        gating_output, dim=1, dtype=torch.float32
+    )
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
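
A small numerical check of the routing that fused_topk adds: softmax over the gate logits, top-k experts per token, then optional renormalization of the selected weights (toy sizes, illustrative only):

import torch

gating_output = torch.randn(3, 8)                       # 3 tokens, 8 experts
probs = torch.softmax(gating_output, dim=1)
topk_weights, topk_ids = torch.topk(probs, k=2, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
# topk_ids: which 2 experts each token is routed to; topk_weights: their mixing coefficients.
assert torch.allclose(topk_weights.sum(dim=-1), torch.ones(3))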

View File

@ -1,215 +0,0 @@
from dataclasses import dataclass
from typing import List, Optional
import torch
import torch.nn as nn
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights
from text_generation_server.layers.marlin.gptq import (
GPTQMarlinWeight,
GPTQMarlinWeightsLoader,
)
if SYSTEM == "cuda":
from moe_kernels.fused_marlin_moe import fused_marlin_moe
else:
fused_marlin_moe = None
try:
major, _minor = torch.cuda.get_device_capability()
has_sm_8_0 = major >= 8
except Exception:
has_sm_8_0 = False
def can_use_marlin_moe_gemm(
*,
quant_method: str,
quantize: str,
sym: bool,
):
return (
SYSTEM == "cuda"
and fused_marlin_moe is not None
and has_sm_8_0
and quantize == "gptq"
and quant_method == "gptq"
and sym
)
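The SM 8.0 gate above can be probed on its own; a small, hedged sketch that also runs on hosts without CUDA, where it simply reports False.

import torch

def has_sm_8_0() -> bool:
    # The fused Marlin MoE kernel needs compute capability >= 8.0 (Ampere or newer).
    try:
        major, _minor = torch.cuda.get_device_capability()
        return major >= 8
    except Exception:
        return False  # no usable CUDA device/driver

print(has_sm_8_0())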
@dataclass
class GPTQMarlinMoEWeight:
qweight: torch.Tensor
qzeros: torch.Tensor
scales: torch.Tensor
g_idx: torch.Tensor
perm: torch.Tensor
is_full_k: bool
class GPTQMarlinSparseMoELayer(nn.Module):
"""
MoE layer that uses a fused GPTQ-Marlin kernel.
"""
def __init__(
self,
*,
n_expert_group: Optional[int],
n_experts: int,
prefix: str,
renormalize: bool,
topk: int,
topk_group: Optional[int],
weights: Weights,
gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj",
):
super().__init__()
if not (
isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
):
raise ValueError(
f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
)
assert (n_expert_group is None) == (
topk_group is None
), "n_expert_group and topk_group must both be None or have some value"
self.n_expert_group = n_expert_group
self.topk = topk
self.topk_group = topk_group
self.renormalize = renormalize
self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix,
n_experts=n_experts,
names=[gate_proj_name, up_proj_name],
weights=weights,
)
self.down_proj = _load_expert_weights_row(
prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
)
self.bits = weights.loader.bits
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
return fused_marlin_moe(
x,
w1=self.gate_up_proj.qweight,
w2=self.down_proj.qweight,
g_idx1=self.gate_up_proj.g_idx,
g_idx2=self.down_proj.g_idx,
perm1=self.gate_up_proj.perm,
perm2=self.down_proj.perm,
w1_scale=self.gate_up_proj.scales,
w2_scale=self.down_proj.scales,
is_full_k1=self.gate_up_proj.is_full_k,
is_full_k2=self.down_proj.is_full_k,
gating_output=gating_output,
topk=self.topk,
renormalize=self.renormalize,
use_grouped_topk=self.n_expert_group is not None,
num_expert_group=self.n_expert_group,
topk_group=self.topk_group,
num_bits=self.bits,
)
def _load_expert_multi_weights_col(
*,
prefix: str,
n_experts: int,
names: List[str],
weights: Weights,
) -> GPTQMarlinMoEWeight:
moe_weight = None
for i in range(n_experts):
weight = weights.get_multi_weights_col(
[f"{prefix}.{i}.{name}" for name in names], 0
)
assert isinstance(weight, GPTQMarlinWeight)
moe_weight = _pack_weight(
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
)
assert moe_weight is not None
return moe_weight
def _load_expert_weights_row(
*,
prefix: str,
n_experts: int,
name: str,
weights: Weights,
) -> GPTQMarlinMoEWeight:
moe_weight = None
for i in range(n_experts):
weight = weights.get_weights_row(
f"{prefix}.{i}.{name}",
)
assert isinstance(weight, GPTQMarlinWeight)
moe_weight = _pack_weight(
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
)
assert moe_weight is not None
return moe_weight
def _pack_weight(
*,
n_experts: int,
expert: int,
moe_weight: Optional[GPTQMarlinMoEWeight],
weight: GPTQMarlinWeight,
) -> GPTQMarlinMoEWeight:
if moe_weight is None:
qweight = torch.empty(
(n_experts,) + weight.qweight.shape,
dtype=weight.qweight.dtype,
device=weight.qweight.device,
)
qzeros = torch.empty(
(n_experts,) + weight.qzeros.shape,
dtype=weight.qzeros.dtype,
device=weight.qzeros.device,
)
scales = torch.empty(
(n_experts,) + weight.scales.shape,
dtype=weight.scales.dtype,
device=weight.scales.device,
)
g_idx = torch.empty(
(n_experts,) + weight.g_idx.shape,
dtype=weight.g_idx.dtype,
device=weight.g_idx.device,
)
perm = torch.empty(
(n_experts,) + weight.perm.shape,
dtype=weight.perm.dtype,
device=weight.perm.device,
)
moe_weight = GPTQMarlinMoEWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
perm=perm,
is_full_k=weight.is_full_k,
)
moe_weight.qweight[expert] = weight.qweight
moe_weight.qzeros[expert] = weight.qzeros
moe_weight.scales[expert] = weight.scales
moe_weight.g_idx[expert] = weight.g_idx
moe_weight.perm[expert] = weight.perm
return moe_weight

View File

@@ -3,13 +3,8 @@ from typing import Optional
 import torch
 import torch.nn as nn
-from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils.weights import UnquantizedWeight, Weights
+from vllm_hpu_extension.ops import DynamicFusedMOE
-
-if SYSTEM == "rocm":
-    from vllm.model_executor.layers.fused_moe import fused_moe
-elif SYSTEM != "ipex":
-    from moe_kernels.fused_moe import fused_moe
 class UnquantizedSparseMoELayer(nn.Module):
@ -23,6 +18,8 @@ class UnquantizedSparseMoELayer(nn.Module):
topk: int, topk: int,
topk_group: Optional[int], topk_group: Optional[int],
weights: Weights, weights: Weights,
scoring_func: Optional[str] = "softmax",
e_score_correction_bias: Optional[float] = None,
gate_proj_name: str = "gate_proj", gate_proj_name: str = "gate_proj",
up_proj_name: str = "up_proj", up_proj_name: str = "up_proj",
down_proj_name: str = "down_proj", down_proj_name: str = "down_proj",
@ -37,6 +34,9 @@ class UnquantizedSparseMoELayer(nn.Module):
self.topk = topk self.topk = topk
self.topk_group = topk_group self.topk_group = topk_group
self.renormalize = renormalize self.renormalize = renormalize
self.weight_block_size = weights.weights_loader.weight_block_size
self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias
self.gate_up_proj = _load_expert_multi_weights_col( self.gate_up_proj = _load_expert_multi_weights_col(
prefix=prefix, prefix=prefix,
@@ -53,30 +53,13 @@ class UnquantizedSparseMoELayer(nn.Module):
             weights=weights,
         )

+        self.hpu_fused_moe = DynamicFusedMOE(n_experts)
+        for i in range(n_experts):
+            self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
+            self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.down_proj[i])
+
     def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
-        if SYSTEM == "rocm":
-            return fused_moe(
-                x,
-                self.gate_up_proj,
-                self.down_proj,
-                gating_output,
-                self.topk,
-                renormalize=self.renormalize,
-                inplace=True,
-            )
-        return fused_moe(
-            x,
-            w1=self.gate_up_proj,
-            w2=self.down_proj,
-            gating_output=gating_output,
-            topk=self.topk,
-            renormalize=self.renormalize,
-            inplace=True,
-            use_grouped_topk=self.n_expert_group is not None,
-            num_expert_group=self.n_expert_group,
-            topk_group=self.topk_group,
-        )
+        return self.hpu_fused_moe(x, gating_output, self.topk)

 def _load_expert_multi_weights_col(
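For readers without the HPU extension, a hedged pure-PyTorch sketch of the computation DynamicFusedMOE stands in for here: top-k softmax routing over per-expert gate/up/down projections with a SiLU gate. The helper name and shapes are illustrative, not the extension's API, and renormalization details may differ.

import torch
import torch.nn.functional as F

def moe_reference(x, gate_up_proj, down_proj, gating_output, topk):
    # x: (tokens, hidden); gate_up_proj: (experts, 2*inter, hidden); down_proj: (experts, hidden, inter)
    weights = F.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(weights, topk, dim=-1)
    out = torch.zeros_like(x)
    for e in range(gate_up_proj.shape[0]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        gate, up = (x[token_idx] @ gate_up_proj[e].T).chunk(2, dim=-1)
        expert_out = (F.silu(gate) * up) @ down_proj[e].T
        out[token_idx] += topk_weights[token_idx, slot, None].to(x.dtype) * expert_out
    return out

x = torch.randn(4, 16)
out = moe_reference(x, torch.randn(8, 64, 16), torch.randn(8, 16, 32), torch.randn(4, 8), topk=2)
print(out.shape)  # torch.Size([4, 16])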

View File

@@ -2,14 +2,10 @@ import os
 import math
 import torch
 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM
-
-if SYSTEM == "cuda":
-    import rotary_emb
-elif SYSTEM == "rocm":
-    from vllm._C import ops
-elif SYSTEM == "ipex":
-    import intel_extension_for_pytorch as ipex
+from habana_frameworks.torch.hpex.kernels import (
+    RotaryPosEmbeddingMode,
+    apply_rotary_pos_emb,
+)

 def _create_inv_freq(dim, base, device):
@@ -30,7 +26,7 @@ def _get_rope_config(config):
 class PositionRotaryEmbedding(nn.Module):
-    def __init__(self, inv_freq, scaling_factor):
+    def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
         super().__init__()
         self.inv_freq = inv_freq
         self._seq_len_cached = 0
@@ -40,6 +36,9 @@ class PositionRotaryEmbedding(nn.Module):
         self._sin_k_cached = None
         self.scaling_factor = scaling_factor
         self.dynamic_args = None
+        self._update_cos_sin_cache(
+            torch.float32, inv_freq.device, max_position_embeddings
+        )
     def forward(
         self,
@@ -48,40 +47,41 @@ class PositionRotaryEmbedding(nn.Module):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        # Such controlflows may add some overhead.
-        if SYSTEM == "cuda":
-            rotary_dim = cos.shape[-1]
-            q1 = query[..., :rotary_dim]
-            q2 = query[..., rotary_dim : 2 * rotary_dim]
-            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
-            k1 = key[..., :rotary_dim]
-            k2 = key[..., rotary_dim : 2 * rotary_dim]
-            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        elif SYSTEM == "rocm":
-            # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
-            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
-            head_size = query.shape[-1]
-            # Inplace operation, updating query and key.
-            ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif SYSTEM == "ipex":
-            ipex.llm.functional.rotary_embedding(
-                query, key, sin, cos, query.size(-1), True
-            )
-        else:
-            raise ValueError(
-                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
-            )
+        num_tokens = query.shape[0]
+        head_size = query.shape[-1]
+        # HPU RoPE kernel requires hidden dimension for cos and sin to be equal
+        # to query hidden dimension, so the original tensors need to be
+        # expanded
+        # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
+        # and expansion of cos/sin tensors via concatenation
+        rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
+        cos = torch.cat((cos, cos), dim=-1)
+        sin = torch.cat((sin, sin), dim=-1)
+        rotary_dim = cos.shape[-1]
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, head_size)
+        query_rot = query[..., :rotary_dim]
+        query_pass = query[..., rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
+
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, head_size)
+        key_rot = key[..., :rotary_dim]
+        key_pass = key[..., rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
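As a sanity reference for what the BLOCKWISE mode computes, a small rotate-half (GPT-NeoX style) implementation in plain PyTorch; this only illustrates the math, it is not the HPU kernel's API.

import torch

def rope_blockwise(x, cos, sin):
    # x: (tokens, heads, rotary_dim); cos/sin: (tokens, rotary_dim // 2) as cached by
    # _update_cos_sin_cache; they are expanded by concatenation, as in the forward above.
    cos = torch.cat((cos, cos), dim=-1)[:, None, :]
    sin = torch.cat((sin, sin), dim=-1)[:, None, :]
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin

q = torch.randn(5, 8, 64)  # 5 tokens, 8 heads, rotary dim 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))
freqs = torch.outer(torch.arange(5, dtype=torch.float32), inv_freq)  # (5, 32)
print(rope_blockwise(q, torch.cos(freqs), torch.sin(freqs)).shape)   # torch.Size([5, 8, 64])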
@classmethod @classmethod
def static(cls, config, dim, base, device): def static(cls, config, dim, base, device):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
scaling_factor = None scaling_factor = None
rope_scaling = _get_rope_config(config) rope_scaling = _get_rope_config(config)
if not hasattr(config, "max_position_embeddings") and hasattr(
config, "max_seq_len"
):
# handling for dbrx
config.max_position_embeddings = config.max_seq_len
if rope_scaling is not None: if rope_scaling is not None:
# `rope_type` is now standard in transformers, but some existing models # `rope_type` is now standard in transformers, but some existing models
# have `type` instead. # have `type` instead.
@ -89,6 +89,17 @@ class PositionRotaryEmbedding(nn.Module):
if rope_type == "linear": if rope_type == "linear":
pass pass
elif rope_type == "default":
pass
elif rope_type == "mrope":
mrope_section = rope_scaling["mrope_section"]
if mrope_section is not None:
return RotaryPositionEmbeddingMultimodalSections(
inv_freq,
scaling_factor,
mrope_section,
config.max_position_embeddings,
)
elif rope_type == "dynamic": elif rope_type == "dynamic":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
return DynamicPositionRotaryEmbedding( return DynamicPositionRotaryEmbedding(
@ -109,7 +120,7 @@ class PositionRotaryEmbedding(nn.Module):
], ],
) )
-            return cls(inv_freq, scaling_factor)
+            return cls(inv_freq, scaling_factor, config.max_position_embeddings)
elif rope_type == "yarn": elif rope_type == "yarn":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
@ -185,12 +196,13 @@ class PositionRotaryEmbedding(nn.Module):
long_inv_freq=long_inv_freq, long_inv_freq=long_inv_freq,
scaling_factor=scaling_factor, scaling_factor=scaling_factor,
original_max_position_embeddings=original_max_position_embeddings, original_max_position_embeddings=original_max_position_embeddings,
max_position_embeddings=config.max_position_embeddings,
) )
else: else:
raise NotImplementedError( raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid" f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
) )
-        return cls(inv_freq, scaling_factor)
+        return cls(inv_freq, scaling_factor, config.max_position_embeddings)
@classmethod @classmethod
def load(cls, config, prefix, weights): def load(cls, config, prefix, weights):
@ -236,7 +248,7 @@ class PositionRotaryEmbedding(nn.Module):
raise NotImplementedError( raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid" f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
) )
-        return cls(inv_freq, scaling_factor)
+        return cls(inv_freq, scaling_factor, config.max_position_embeddings)
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
@ -257,17 +269,7 @@ class PositionRotaryEmbedding(nn.Module):
self._cos_cached = torch.cos(freqs).to(dtype) self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype)
-    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
-        """
-        Return cos and sin for the asked position ids
-        """
-        if SYSTEM == "rocm":
-            # For RoCm, we always use float cos/sin to avoid a cast.
-            # For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
-            # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
-            dtype = torch.float32
-        self._update_cos_sin_cache(dtype, position_ids.device, max_s)
+    def get_cos_sin(self, position_ids: torch.Tensor):
         cos = torch.index_select(self._cos_cached, 0, position_ids)
         sin = torch.index_select(self._sin_cached, 0, position_ids)
@ -283,6 +285,7 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
long_inv_freq, long_inv_freq,
scaling_factor, scaling_factor,
original_max_position_embeddings, original_max_position_embeddings,
max_position_embeddings,
): ):
super(PositionRotaryEmbedding, self).__init__() super(PositionRotaryEmbedding, self).__init__()
self.short_inv_freq = short_inv_freq self.short_inv_freq = short_inv_freq
@ -295,6 +298,9 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None self._cos_k_cached = None
self._sin_k_cached = None self._sin_k_cached = None
self.dynamic_args = None self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
# Reset the tables if the sequence length has changed, # Reset the tables if the sequence length has changed,
@ -348,6 +354,9 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
self._cos_k_cached = None self._cos_k_cached = None
self._sin_k_cached = None self._sin_k_cached = None
self.dynamic_args = None self.dynamic_args = None
self._update_cos_sin_cache(
torch.float32, short_inv_freq.device, max_position_embeddings
)
def _update_cos_sin_cache(self, dtype, device, seqlen): def _update_cos_sin_cache(self, dtype, device, seqlen):
if ( if (
@ -383,7 +392,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding): class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor): def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
-        super().__init__(inv_freq, scaling_factor)
+        super().__init__(inv_freq, scaling_factor, max_position_embeddings)
self.dim = dim self.dim = dim
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.base = base self.base = base
@ -461,7 +470,9 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
mscale_all_dim: float, mscale_all_dim: float,
): ):
inv_freq = _create_inv_freq(dim, base, device) inv_freq = _create_inv_freq(dim, base, device)
-        super().__init__(inv_freq, scaling_factor)
+        super().__init__(
+            inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
+        )
self.dim = dim self.dim = dim
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
self.base = base self.base = base
@ -546,3 +557,50 @@ def apply_llama3_scaling(
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq) new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
def __init__(
self,
inv_freq: torch.Tensor,
scaling_factor: float,
sections: list,
max_position_embeddings,
):
self.sections = sections
self._cos_cached = None
self._sin_cached = None
self.section_indices = (
torch.arange(len(self.sections))
.repeat_interleave(torch.tensor(self.sections))
.view(1, 1, -1)
.to(inv_freq.device)
)
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
def _update_cos_sin_cache(
self, dtype: torch.dtype, device: torch.device, seqlen: int
):
# always cache the cos/sin for the full sequence length to avoid
# recomputing if the sequence length is smaller than the cached one
if (
seqlen > self._seq_len_cached
or self._cos_cached.device != device
or self._cos_cached.dtype != dtype
):
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
self._sections = self.section_indices.expand(seqlen, -1, -1)
def get_cos_sin(
self,
position_ids: torch.Tensor,
):
slen = position_ids.shape[0]
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
return cos, sin
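A hedged sketch of the multimodal-section gather above: section_indices assigns each rotary channel to one of the position-id streams (e.g. temporal/height/width for mrope), and the gather picks cos values per channel from that stream. The section sizes and the random cache stand-in are illustrative only.

import torch

sections = [4, 2, 2]                    # hypothetical mrope_section split
seqlen, dim_half = 16, sum(sections)
position_ids = torch.randint(0, seqlen, (6, len(sections)))  # (tokens, n_streams)
cos_cached = torch.rand(seqlen, dim_half)                    # stands in for _cos_cached

# Per-channel stream index, built exactly like section_indices in __init__ above.
section_indices = (
    torch.arange(len(sections)).repeat_interleave(torch.tensor(sections)).view(1, 1, -1)
)
sections_exp = section_indices.expand(seqlen, -1, -1)        # like self._sections

cos = cos_cached[position_ids].gather(1, sections_exp[: position_ids.shape[0]])
print(cos.shape)  # torch.Size([6, 1, 8])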

View File

@@ -2,10 +2,8 @@ import torch
 from torch.nn import functional as F
 from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
-from text_generation_server.utils.import_utils import SYSTEM
-
-if SYSTEM == "ipex":
-    import intel_extension_for_pytorch as ipex
+import habana_frameworks.torch as htorch

 class LayerConcat(torch.nn.Module):
@@ -90,14 +88,10 @@ class TensorParallelHead(SuperLayer):
             local_out = gather_input.T

         torch.mm(input, self.linear.weight.T, out=local_out)
-        if SYSTEM == "ipex":
-            ipex.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
-        else:
-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+        htorch.core.mark_step()
+        torch.distributed.all_gather_into_tensor(
+            world_out, gather_input, group=self.process_group
+        )

         if input.shape[0] == 1:
             return world_out
@@ -107,10 +101,9 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if SYSTEM == "ipex":
-            ipex.distributed.all_gather(world_output, output, group=self.process_group)
-        else:
-            torch.distributed.all_gather(world_output, output, group=self.process_group)
+        htorch.core.mark_step()
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
         world_output = torch.cat(world_output, dim=-1)
         return world_output
@@ -202,10 +195,11 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if SYSTEM == "ipex":
-                ipex.distributed.all_reduce(out, group=self.process_group)
-            else:
-                torch.distributed.all_reduce(out, group=self.process_group)
+            # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
+            # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
+            # (which is required for tensor parallel HPUGraph inference)
+            htorch.core.mark_step()
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
@@ -242,8 +236,9 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            if SYSTEM == "ipex":
-                ipex.distributed.all_reduce(out, group=self.process_group)
-            else:
-                torch.distributed.all_reduce(out, group=self.process_group)
+            # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
+            # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
+            # (which is required for tensor parallel HPUGraph inference)
+            htorch.core.mark_step()
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
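A minimal, hedged sketch of the collective-plus-flush pattern used above. It runs as a single-process gloo group so the call path is executable anywhere; the mark_step flush is guarded because habana_frameworks only exists on Gaudi hosts.

import os
import torch
import torch.distributed as dist

try:
    import habana_frameworks.torch as htorch  # only present on Gaudi machines
except ImportError:
    htorch = None

def all_reduce_with_flush(out: torch.Tensor, group=None) -> torch.Tensor:
    # On HPU, flush the lazy graph before the collective (see the FIXME above);
    # this is a no-op on other backends.
    if htorch is not None:
        htorch.core.mark_step()
    dist.all_reduce(out, group=group)
    return out

if __name__ == "__main__":
    # Single-process process group just to make the call path runnable.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29512")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(all_reduce_with_flush(torch.ones(4)).tolist())  # [1.0, 1.0, 1.0, 1.0]
    dist.destroy_process_group()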

View File

@@ -1,3 +1,5 @@
+# ruff: noqa: F821
+# the above line disables the `undefined-name` rule for the model type variables
 import torch
 import os
@@ -8,6 +10,7 @@ from huggingface_hub import hf_hub_download, HfApi
 from typing import Optional
 from pathlib import Path
 from typing import List, Dict
+import enum

 # Needed to properly setup habana_frameworks
@@ -16,15 +19,10 @@ from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM
 from text_generation_server.models.bloom import BLOOM
 from text_generation_server.models.starcoder import StarCoder
-from text_generation_server.models.vlm_causal_lm import VlmCausalLM
-from text_generation_server.models.custom_modeling.mllama import (
-    MllamaForConditionalGeneration,
-)
-from text_generation_server.models.custom_modeling.llava_next import (
-    LlavaNextForConditionalGeneration,
-)
-# from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
+from text_generation_server.models.custom_modeling.flash_phi_moe_modeling import (
+    PhiMoEConfig,
+)
 from text_generation_server.utils.adapter import (
     AdapterParameters,
     build_layer_weight_lookup,
@@ -33,9 +31,285 @@ from text_generation_server.utils.adapter import (
 )
 from text_generation_server.adapters.lora import LoraWeights
+from text_generation_server.utils.log import log_master
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
__all__ = [
"Model",
"CausalLM",
"Seq2SeqLM",
"get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION
FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
FLASH_ATTENTION = False
if ATTENTION == "paged":
FLASH_ATTENTION = True
try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_vlm_causal_lm import FlashVlmCausalLM
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLM
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
FlashDeepseekV2ForCausalLM,
DeepseekV2Config,
)
from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
FlashDeepseekV3ForCausalLM,
DeepseekV3Config,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
FlashCohereForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
FlashGemma2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
FlashDbrxForCausalLM,
DbrxConfig,
)
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
RWConfig,
FlashRWForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
FlashGPTNeoXForCausalLM,
)
from text_generation_server.models.pali_gemma import (
PaliGemmaBatch,
)
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
PaliGemmaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
FlashPhiForCausalLM,
)
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.custom_modeling.flash_mllama import (
FlashMllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_llava_next import (
FlashLlavaNextForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
FlashStarcoder2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
FlashMixtralForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
FlashGPT2ForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
FlashGPTJForCausalLM,
)
from text_generation_server.models.custom_modeling.idefics2 import (
Idefics2ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.idefics3 import (
Idefics3ForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_vl import (
Qwen2VLForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.qwen2_5_vl import (
Qwen2_5VLForConditionalGeneration,
Qwen2_5_VLConfig,
Qwen2_5_VLProcessor,
)
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
except ImportError as e:
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
SUPPORTS_WINDOWING = False
FLASH_ATTENTION = False
if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
class ModelType(enum.Enum):
DEEPSEEK_V2 = {
"type": "deepseek_v2",
"name": "Deepseek V2",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
}
DEEPSEEK_V3 = {
"type": "deepseek_v3",
"name": "Deepseek V3",
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
}
IDEFICS2 = {
"type": "idefics2",
"name": "Idefics 2",
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
"multimodal": True,
}
IDEFICS3 = {
"type": "idefics3",
"name": "Idefics 3",
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
"multimodal": True,
}
LLAVA_NEXT = {
"type": "llava_next",
"name": "Llava Next (1.6)",
"url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
"multimodal": True,
}
LLAMA = {
"type": "llama",
"name": "Llama",
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
}
PHI3 = {
"type": "phi3",
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GRANITE = {
"type": "granite",
"name": "Granite",
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
PALIGEMMA = {
"type": "paligemma",
"name": "PaliGemma",
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
}
GEMMA2 = {
"type": "gemma2",
"name": "Gemma2",
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
}
COHERE = {
"type": "cohere",
"name": "Cohere",
"url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
}
DBRX = {
"type": "dbrx",
"name": "Dbrx",
"url": "https://huggingface.co/databricks/dbrx-instruct",
}
MAMBA = {
"type": "mamba",
"name": "Mamba",
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
}
MISTRAL = {
"type": "mistral",
"name": "Mistral",
"url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
}
MIXTRAL = {
"type": "mixtral",
"name": "Mixtral",
"url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
}
GPT_BIGCODE = {
"type": "gpt_bigcode",
"name": "Gpt Bigcode",
"url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
}
PHI = {
"type": "phi",
"name": "Phi",
"url": "https://huggingface.co/microsoft/phi-1_5",
}
PHI_MOE = {
"type": "phimoe",
"name": "PhiMoe",
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
}
BAICHUAN = {
"type": "baichuan",
"name": "Baichuan",
"url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
}
FALCON = {
"type": "falcon",
"name": "Falcon",
"url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
}
STARCODER2 = {
"type": "starcoder2",
"name": "StarCoder 2",
"url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
}
QWEN2 = {
"type": "qwen2",
"name": "Qwen 2",
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
}
QWEN2_VL = {
"type": "qwen2_vl",
"name": "Qwen 2 VL",
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
}
QWEN2_5_VL = {
"type": "qwen2_5_vl",
"name": "Qwen 2.5 VL",
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
}
GALACTICA = {
"type": "galactica",
"name": "Galactica",
"url": "https://huggingface.co/facebook/galactica-120b",
}
SANTACODER = {
"type": "santacoder",
"name": "SantaCoder",
"url": "https://huggingface.co/bigcode/santacoder",
}
GPT2 = {
"type": "gpt2",
"name": "Gpt2",
"url": "https://huggingface.co/openai-community/gpt2",
}
GPT_NEOX = {
"type": "gpt_neox",
"name": "Gpt Neox",
"url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
}
GPTJ = {
"type": "gptj",
"name": "Gptj",
"url": "https://huggingface.co/EleutherAI/gpt-j-6b",
}
MLLAMA = {
"type": "mllama",
"name": "Mllama",
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
"multimodal": True,
}
__GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]
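The loop above exports every ModelType member as a module-level string constant, which is what lets the dispatch below compare model_type against bare names like LLAMA (and why the file now starts with the ruff noqa for undefined names). A self-contained illustration of the same idiom:

import enum

class ModelType(enum.Enum):
    LLAMA = {"type": "llama", "name": "Llama"}
    MIXTRAL = {"type": "mixtral", "name": "Mixtral"}

__GLOBALS = globals()  # equivalent to locals() at module scope, as in the snippet above
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]

print(LLAMA, MIXTRAL)  # noqa: F821  -> prints: llama mixtral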
 SDP_ON_BF16 = int(os.environ.get("SDP_ON_BF16", 0))
 # Disable gradients
@@ -53,9 +327,7 @@ def get_model(
     trust_remote_code: bool,
     max_input_tokens: int,
 ) -> Model:
-    adapt_transformers_to_gaudi()
-    if SDP_ON_BF16 == 1:
-        torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
+    global FLASH_ATTENTION
     if speculate is not None:
         set_speculate(speculate)
@@ -177,9 +449,393 @@ def get_model(
     model_type = config_dict["model_type"]
kv_cache_dtype = dtype
if FLASH_ATTENTION:
if model_type == DEEPSEEK_V2:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV2Config,
head_size=head_size,
)
elif model_type == DEEPSEEK_V3:
head_size = max(
config_dict.get("qk_nope_dim", 128)
+ config_dict.get("qk_rope_dim", 64),
config_dict.get("v_head_dim", 128),
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashDeepseekV3ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
default_dtype=torch.bfloat16,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DeepseekV3Config,
head_size=head_size,
)
elif (
model_type == GPT_BIGCODE
or model_type == GPT2
and model_id.startswith("bigcode/")
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashSantacoderForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
aliases={"transformer.wte.weight": ["lm_head.weight"]},
num_kv_heads=1,
)
elif model_type == GPT2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPT2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPTJ:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTJForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GPT_NEOX:
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
GPTNeoXConfig,
)
return FlashCausalLM(
model_id=model_id,
model_class=FlashGPTNeoXForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=GPTNeoXConfig,
)
elif model_type == PHI:
return FlashCausalLM(
model_id=model_id,
model_class=FlashPhiForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == PHI_MOE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
config_class=PhiMoEConfig,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == BAICHUAN:
return FlashCausalLM(
model_id=model_id,
model_class=FlashLlamaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemmaForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == GEMMA2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashGemma2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == COHERE:
return FlashCausalLM(
model_id=model_id,
model_class=FlashCohereForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == DBRX:
return FlashCausalLM(
model_id=model_id,
model_class=FlashDbrxForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Dbrx works better in bfloat16.
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=DbrxConfig,
)
elif (
model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
and not sharded
and not config_dict.get("alibi", False)
):
return FlashCausalLM(
model_id=model_id,
model_class=FlashRWForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
aliases={
"lm_head.weight": ["transformer.word_embeddings.weight"],
"transformer.word_embeddings.weight": ["lm_head.weight"],
},
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=RWConfig,
)
elif model_type == MISTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMistralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == MIXTRAL:
return FlashCausalLM(
model_id=model_id,
model_class=FlashMixtralForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == STARCODER2:
return FlashCausalLM(
model_id=model_id,
model_class=FlashStarcoder2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2:
return FlashCausalLM(
model_id=model_id,
model_class=Qwen2ForCausalLM,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == QWEN2_5_VL:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Qwen2_5VLForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
config_class=Qwen2_5_VLConfig,
processor_class=Qwen2_5_VLProcessor,
)
elif model_type == MLLAMA:
return FlashMllamaCausalLM(
model_id=model_id,
model_class=FlashMllamaForConditionalGeneration,
batch_class=FlashMllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)
elif model_type == IDEFICS2:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics2ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
)
elif model_type == IDEFICS3:
return FlashVlmCausalLM(
model_id=model_id,
model_class=Idefics3ForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
# XXX: Extremely important to cap resolution in order to limit
# VRAM usage.
processor_kwargs={"size": {"longest_edge": 1456}},
)
elif model_type == PALIGEMMA:
return FlashVlmCausalLM(
model_id=model_id,
model_class=PaliGemmaForConditionalGeneration,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
# Works better for these models
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
batch_class=PaliGemmaBatch,
)
elif model_type == LLAVA_NEXT:
return FlashVlmCausalLM(
model_class=FlashLlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
trust_remote_code=trust_remote_code,
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
from text_generation_server.models.custom_modeling.mllama import (
MllamaForConditionalGeneration,
)
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
adapt_transformers_to_gaudi()
if SDP_ON_BF16 == 1:
torch._C._set_math_sdp_allow_fp16_bf16_reduction(True)
if model_type == "gpt_bigcode": if model_type == "gpt_bigcode":
return StarCoder(model_id=model_id, revision=revision, dtype=dtype) return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
if model_type == "bloom": if model_type == "bloom":
return BLOOM( return BLOOM(
model_id=model_id, model_id=model_id,

View File

@@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
             past_value.view(-1, *past_value.shape[-2:]),
         )

-        if CUSTOM_KERNELS_ENABLED:
+        if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
             assert self.training is False, "Only foward pass was implemented"
             assert (
                 attention_mask.shape[-1] < 4096
@@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
     @staticmethod
     def _convert_to_bloom_cache(
-        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
         """
         Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))

View File

@@ -28,10 +28,10 @@ from typing import Optional, List, Tuple
 from text_generation_server.layers.attention import (
     paged_attention,
     attention,
-    reshape_and_cache,
     Seqlen,
+    HPUPagedAttentionMetadata,
 )
-from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.layers.attention.kv_cache import get_kv_scales
 from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -39,7 +39,6 @@ from text_generation_server.layers import (
     SpeculativeHead,
     get_linear,
 )
-from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
 from text_generation_server.layers.layernorm import (
     FastLayerNorm,
 )
@@ -47,11 +46,10 @@ from text_generation_server.layers.rotary import (
     PositionRotaryEmbedding,
 )
 from text_generation_server.utils.weights import UnquantizedWeight
-
-if SYSTEM == "cuda":
-    import dropout_layer_norm
-else:
-    dropout_layer_norm = None
+from habana_frameworks.torch.hpex.kernels import (
+    RotaryPosEmbeddingMode,
+    apply_rotary_pos_emb,
+)

 class CohereRotary(PositionRotaryEmbedding):
@@ -63,38 +61,25 @@ class CohereRotary(PositionRotaryEmbedding):
         sin: torch.Tensor,
     ):
         # Such controlflows may add some overhead.
-        if SYSTEM == "cuda":
-            import rotary_emb
-
-            q1 = query[..., ::2]
-            q2 = query[..., 1::2]
-
-            rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
-
-            k1 = key[..., ::2]
-            k2 = key[..., 1::2]
-
-            rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
-        elif SYSTEM == "rocm":
-            from vllm._C import ops
-
-            # NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
-            # Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
-
-            head_size = query.shape[-1]
-
-            # Inplace operation, updating query and key.
-            ops.rotary_embedding(query, key, head_size, cos, sin, False)
-        elif SYSTEM == "ipex":
-            import intel_extension_for_pytorch as ipex
-
-            ipex.llm.functional.rotary_embedding(
-                query, key, sin, cos, query.size(-1), False
-            )
-        else:
-            raise ValueError(
-                "Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
-            )
+        num_tokens = query.shape[0]
+        head_size = query.shape[-1]
+        rope_mode = RotaryPosEmbeddingMode.PAIRWISE
+        sin = torch.repeat_interleave(sin, 2, dim=-1)
+        cos = torch.repeat_interleave(cos, 2, dim=-1)
+        rotary_dim = cos.shape[-1]
+        query_shape = query.shape
+        query = query.view(num_tokens, -1, head_size)
+        query_rot = query[..., :rotary_dim]
+        query_pass = query[..., rotary_dim:]
+        query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
+        query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
+
+        key_shape = key.shape
+        key = key.view(num_tokens, -1, head_size)
+        key_rot = key[..., :rotary_dim]
+        key_pass = key[..., rotary_dim:]
+        key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
+        key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
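For reference, the PAIRWISE mode used by CohereRotary rotates interleaved channel pairs (GPT-J style), which is why cos/sin are repeat_interleave'd above. A plain-PyTorch illustration of that math, not the HPU kernel itself:

import torch

def rope_pairwise(x, cos, sin):
    # x: (tokens, heads, rotary_dim); cos/sin: (tokens, rotary_dim // 2).
    cos = torch.repeat_interleave(cos, 2, dim=-1)[:, None, :]
    sin = torch.repeat_interleave(sin, 2, dim=-1)[:, None, :]
    x1 = x[..., ::2]   # even channels
    x2 = x[..., 1::2]  # odd channels
    rotated = torch.stack((-x2, x1), dim=-1).flatten(-2)  # (-x2, x1) re-interleaved
    return x * cos + rotated * sin

q = torch.randn(5, 8, 64)
freqs = torch.outer(torch.arange(5, dtype=torch.float32), torch.rand(32))
print(rope_pairwise(q, torch.cos(freqs), torch.sin(freqs)).shape)  # torch.Size([5, 8, 64])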
class CohereLayerNorm(nn.Module): class CohereLayerNorm(nn.Module):
@@ -107,49 +92,18 @@ class CohereLayerNorm(nn.Module):
         self.eps = eps

     def forward(self, hidden_states):
-        if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda":
-            hidden_states = hidden_states.reshape(
-                -1, self.weight.shape[0], self.weight.shape[1]
-            )
-            input_dtype = hidden_states.dtype
-            hidden_states = hidden_states.to(torch.float32)
-            mean = hidden_states.mean(-1, keepdim=True)
-            hidden_states_minus_mean = hidden_states - mean
-            variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
-            hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
-            hidden_states = self.weight.to(torch.float32) * hidden_states
-            hidden_states = hidden_states.view(-1, self.weight.shape[1])
-            return hidden_states.to(input_dtype)
-
-        (
-            hidden_states,
-            *rest,
-        ) = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            None,
-            self.ones,
-            None,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.eps,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
-
-        # Required to apply one weight matrix per head
-        hidden_states = hidden_states.view(
-            -1, self.weight.shape[0], self.weight.shape[1]
-        )
-        hidden_states = self.weight * hidden_states
-        hidden_states = hidden_states.view(-1, self.weight.shape[1])
-
-        return hidden_states
+        hidden_states = hidden_states.reshape(
+            -1, self.weight.shape[0], self.weight.shape[1]
+        )
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        mean = hidden_states.mean(-1, keepdim=True)
+        hidden_states_minus_mean = hidden_states - mean
+        variance = hidden_states_minus_mean.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states_minus_mean * torch.rsqrt(variance + self.eps)
+        hidden_states = self.weight.to(torch.float32) * hidden_states
+        hidden_states = hidden_states.view(-1, self.weight.shape[1])
+        return hidden_states.to(input_dtype)
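The retained CohereLayerNorm path is a per-head LayerNorm in float32 with a learned per-head weight and no bias. A small runnable cross-check against F.layer_norm (the eps value is illustrative):

import torch
import torch.nn.functional as F

tokens, n_heads, head_dim = 3, 4, 8
x = torch.randn(tokens, n_heads, head_dim)
weight = torch.randn(n_heads, head_dim)

mean = x.mean(-1, keepdim=True)
var = (x - mean).pow(2).mean(-1, keepdim=True)
manual = (x - mean) * torch.rsqrt(var + 1e-5) * weight

ref = F.layer_norm(x, (head_dim,), eps=1e-5) * weight
print(torch.allclose(manual, ref, atol=1e-5))  # True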
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights):
@ -229,6 +183,7 @@ class FlashCohereAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.use_qk_norm = config.use_qk_norm self.use_qk_norm = config.use_qk_norm
if self.use_qk_norm: if self.use_qk_norm:
@ -264,10 +219,9 @@ class FlashCohereAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
query, key, value = qkv.split( query, key, value = qkv.split(
@@ -291,30 +245,35 @@ class FlashCohereAttention(torch.nn.Module):
         self.rotary_emb(query, key, cos, sin)

-        reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
+        kv_cache.store(
+            key=key,
+            value=value,
+            slots=slots,
+            kv_scales=self.kv_scales,
+        )

         # Prefill
         if cu_seqlen_prefill is not None:
-            # flash attention
+            # sdpa
             attn_output = attention(
-                query,
-                kv_cache[0] if PREFILL_IN_KV_CACHE else key,
-                kv_cache[1] if PREFILL_IN_KV_CACHE else value,
-                seqlen,
-                block_tables,
-                self.softmax_scale,
+                query=query,
+                key=key,
+                value=value,
+                kv_cache=kv_cache,
+                kv_scales=self.kv_scales,
+                seqlen=seqlen,
+                softmax_scale=self.softmax_scale,
             )
         # Decode
         else:
             attn_output = paged_attention(
                 query,
-                kv_cache[0],
-                kv_cache[1],
+                kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
-                block_tables,
                 seqlen,
-                max_s,
+                kv_scales=self.kv_scales,
+                hpu_attention_meta=hpu_attention_meta,
             )

         return self.o_proj(
@ -386,10 +345,9 @@ class FlashCohereLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -400,10 +358,9 @@ class FlashCohereLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
mlp_output = self.mlp(normed_hidden_states) mlp_output = self.mlp(normed_hidden_states)
@ -452,18 +409,15 @@ class FlashCohereModel(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: torch.Tensor, seqlen: torch.Tensor,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
@ -475,10 +429,9 @@ class FlashCohereModel(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -516,11 +469,9 @@ class FlashCohereForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -529,10 +480,9 @@ class FlashCohereForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
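The prefill/decode split used throughout these models boils down to: prefill runs causal attention over all prompt tokens at once (the "sdpa" path), while decode attends a single new query against the cached keys/values. A conceptual sketch with PyTorch's built-in scaled_dot_product_attention; it is not the project's attention/paged_attention API and ignores paging and variable-length batching.

import torch
import torch.nn.functional as F

seq, heads, head_dim = 6, 8, 64
q = torch.randn(1, heads, seq, head_dim)
k = torch.randn(1, heads, seq, head_dim)
v = torch.randn(1, heads, seq, head_dim)

# Prefill: every prompt token attends to itself and everything before it.
prefill_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Decode: one fresh query token attends over the full cached context.
q_new = torch.randn(1, heads, 1, head_dim)
decode_out = F.scaled_dot_product_attention(q_new, k, v)
print(prefill_out.shape, decode_out.shape)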

View File

@ -20,17 +20,14 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention.kv_cache import get_kv_scales
if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
PREFILL_IN_KV_CACHE, HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
FastLinear, FastLinear,
@ -46,6 +43,7 @@ from text_generation_server.layers.rotary import (
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
from vllm_hpu_extension.ops import DynamicFusedMOE
class DbrxAttentionConfig(PretrainedConfig): class DbrxAttentionConfig(PretrainedConfig):
@ -290,6 +288,7 @@ class DbrxAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -309,10 +308,9 @@ class DbrxAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
if self.clip_qkv is not None: if self.clip_qkv is not None:
@ -330,30 +328,35 @@ class DbrxAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], value=kv[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
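Editor's note: the attention hunks in this and the following files all follow one pattern: the freshly projected K/V pair is written into the paged cache with kv_cache.store, prefill calls attention over the current tokens, and decode calls paged_attention driven by hpu_attention_meta instead of block_tables/max_s. A minimal sketch of that shared pattern, reusing the argument names from this diff (the helper name run_attention and the standalone framing are ours, not part of the PR, and the TGI HPU backend must be installed for the imports to resolve):

from typing import Optional
import torch
from text_generation_server.layers.attention import (
    attention,
    paged_attention,
    Seqlen,
    HPUPagedAttentionMetadata,
)

def run_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache,
    kv_scales,
    slots: torch.Tensor,
    seqlen: Seqlen,
    cu_seqlen_prefill: Optional[torch.Tensor],
    hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    # Write this step's K/V into the paged cache at the given slots.
    kv_cache.store(key=key, value=value, slots=slots, kv_scales=kv_scales)
    if cu_seqlen_prefill is not None:
        # Prefill: attend over the tokens of the current sequences.
        return attention(
            query=query,
            key=key,
            value=value,
            kv_cache=kv_cache,
            kv_scales=kv_scales,
            seqlen=seqlen,
            softmax_scale=softmax_scale,
        )
    # Decode: read K/V back from the cache; HPU metadata replaces block_tables/max_s.
    return paged_attention(
        query,
        kv_cache,
        kv_head_mapping,
        softmax_scale,
        seqlen,
        kv_scales=kv_scales,
        hpu_attention_meta=hpu_attention_meta,
    )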
@ -387,10 +390,9 @@ class DbrxNormAttentionNorm(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
normed_hidden_states, res = self.norm_1(hidden_states, residual) normed_hidden_states, res = self.norm_1(hidden_states, residual)
@ -401,10 +403,9 @@ class DbrxNormAttentionNorm(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
# faster post attention rms norm # faster post attention rms norm
@ -482,18 +483,15 @@ class BlockSparseMoE(nn.Module):
self.process_group = weights.process_group self.process_group = weights.process_group
self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
for i in range(self.num_experts):
self.hpu_fused_moe.MoeOp.w13_list[i].set_weight(self.wv1[i])
self.hpu_fused_moe.MoeOp.w2_list[i].set_weight(self.w2[i])
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(x) router_logits = self.gate(x)
out = fused_moe( out = self.hpu_fused_moe(x, router_logits, self.top_k)
x,
self.wv1,
self.w2,
router_logits,
self.top_k,
renormalize=self.moe_normalize_expert_weights,
inplace=True,
)
# Reduce sum # Reduce sum
if self.process_group.size() > 1: if self.process_group.size() > 1:
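Editor's note: the BlockSparseMoE change swaps the CUDA fused_moe kernel for vllm_hpu_extension's DynamicFusedMOE: each expert's fused gate/up weight and down-projection weight is registered once at load time, and the forward pass becomes a single call with the router logits and top_k. A rough wiring sketch under assumptions: the toy sizes and random expert weights below are placeholders (the real shapes come from the checkpoint's wv1/w2 tensors above), and actually running the op requires an HPU device with vllm_hpu_extension installed.

import torch
from vllm_hpu_extension.ops import DynamicFusedMOE

num_experts, top_k, hidden_dim, inter_dim = 4, 2, 64, 128  # toy sizes, not from the PR

# Placeholder per-expert weights standing in for wv1 / w2 loaded from the checkpoint.
wv1 = [torch.randn(2 * inter_dim, hidden_dim) for _ in range(num_experts)]
w2 = [torch.randn(hidden_dim, inter_dim) for _ in range(num_experts)]

hpu_fused_moe = DynamicFusedMOE(num_experts)
for i in range(num_experts):
    # Register each expert's weights once, as in BlockSparseMoE.__init__ above.
    hpu_fused_moe.MoeOp.w13_list[i].set_weight(wv1[i])
    hpu_fused_moe.MoeOp.w2_list[i].set_weight(w2[i])

x = torch.randn(8, hidden_dim)
router_logits = torch.randn(8, num_experts)  # normally self.gate(x)
out = hpu_fused_moe(x, router_logits, top_k)  # replaces the old fused_moe(...) call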
@ -620,10 +618,9 @@ class DbrxLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
# Self Attention # Self Attention
attn_output, attn_res = self.attn( attn_output, attn_res = self.attn(
@ -633,10 +630,9 @@ class DbrxLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
moe_output = self.moe(attn_output) moe_output = self.moe(attn_output)
@ -677,18 +673,15 @@ class DbrxModel(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].attn.self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@ -699,10 +692,9 @@ class DbrxModel(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -732,11 +724,9 @@ class FlashDbrxForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -745,10 +735,9 @@ class FlashDbrxForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -33,21 +33,14 @@ from text_generation_server.layers.attention import (
Seqlen, Seqlen,
attention, attention,
paged_attention, paged_attention,
reshape_and_cache, HPUPagedAttentionMetadata,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.weights import Weights from text_generation_server.utils.weights import Weights
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class DeepseekV2Config(PretrainedConfig): class DeepseekV2Config(PretrainedConfig):
def __init__( def __init__(
@ -232,6 +225,8 @@ class DeepseekV2Attention(torch.nn.Module):
), ),
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load( self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
) )
@ -260,11 +255,10 @@ class DeepseekV2Attention(torch.nn.Module):
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor, cu_seqlen_prefill: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor], kv_cache: KVCache,
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
): ):
if self.q_lora_rank is None: if self.q_lora_rank is None:
query = self.q_proj(hidden_states) query = self.q_proj(hidden_states)
@ -321,30 +315,35 @@ class DeepseekV2Attention(torch.nn.Module):
value, (0, self.head_pad_size - self.value_head_size), value=0 value, (0, self.head_pad_size - self.value_head_size), value=0
) )
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key, key=key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value, value=value,
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
# Remove padding. # Remove padding.
@ -387,27 +386,11 @@ class DeepseekV2MLP(nn.Module):
self.quantize = config.quantize self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True): def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
if ( gate_up_states = self.gate_up_proj(hidden_states)
SYSTEM == "rocm" gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
and self.hidden_act == "silu" return self.down_proj(
and hidden_states.dtype == torch.float16 self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
and hidden_states.shape[0] == 1 )
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out, reduce=reduce)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class DeepseekV2MoE(nn.Module): class DeepseekV2MoE(nn.Module):
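Editor's note: with the ROCm LLMM_Silu fast path removed, the MLP forward above is the plain SwiGLU computation: the fused gate_up projection is viewed as [tokens, 2, intermediate_size], the first half goes through SiLU and is multiplied with the second half, then the result is down-projected. A self-contained illustration of that split (toy sizes, no TGI layers involved):

import torch
import torch.nn.functional as F

tokens, intermediate_size = 4, 8  # toy sizes
# Stand-in for self.gate_up_proj(hidden_states): gate and up halves fused on the feature dim.
gate_up_states = torch.randn(tokens, 2 * intermediate_size)
gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
gate, up = gate_up_states[:, 0], gate_up_states[:, 1]
mlp_hidden = F.silu(gate) * up  # self.act(gate) * up when hidden_act == "silu"
assert mlp_hidden.shape == (tokens, intermediate_size)  # then fed into self.down_proj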
@ -520,10 +503,9 @@ class DeepseekV2Layer(nn.Module):
sin: torch.Tensor, sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor, cu_seqlen_prefill: torch.Tensor,
kv_cache, kv_cache,
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
): ):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual) normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -534,10 +516,9 @@ class DeepseekV2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
# faster post attention rms norm # faster post attention rms norm
@ -583,18 +564,15 @@ class DeepseekV2Model(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@ -605,10 +583,9 @@ class DeepseekV2Model(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -635,11 +612,9 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -648,10 +623,9 @@ class FlashDeepseekV2ForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -0,0 +1,642 @@
# coding=utf-8
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Type
import torch
import torch.distributed
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from text_generation_server.layers import (
FastLinear,
SpeculativeHead,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
get_linear,
)
from text_generation_server.layers.attention import (
Seqlen,
attention,
paged_attention,
HPUPagedAttentionMetadata,
)
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
from text_generation_server.utils.weights import Weights
class DeepseekV3Config(PretrainedConfig):
def __init__(
self,
vocab_size=102400,
hidden_size=4096,
intermediate_size=11008,
moe_intermediate_size=1407,
num_hidden_layers=30,
num_attention_heads=32,
num_key_value_heads=32,
n_shared_experts=2,
n_routed_experts=160,
ep_size=1,
routed_scaling_factor=1.0,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method="gready",
n_group=8,
topk_group=3,
num_experts_per_tok=6,
moe_layer_freq=1,
first_k_dense_replace=0,
norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
seq_aux=True,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=100000,
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
self.aux_loss_alpha = aux_loss_alpha
self.seq_aux = seq_aux
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
if tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Deepseek V2 models."
)
if ep_size != 1:
raise ValueError(
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
class DeepseekV3Attention(torch.nn.Module):
def __init__(
self,
prefix: str,
config,
weights: Weights,
):
super().__init__()
self.num_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
self.qk_nope_head_dim = config.qk_nope_head_dim
self.qk_rope_head_dim = config.qk_rope_head_dim
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
self.value_head_size = config.v_head_dim
self.head_pad_size = max(self.head_size, self.value_head_size)
self.rotary_emb = PositionRotaryEmbedding.static(
config=config,
dim=self.qk_rope_head_dim,
base=config.rope_theta,
device=weights.device,
)
mscale = get_mscale(
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
)
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
)
if self.q_lora_rank is None:
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=config.attention_bias,
)
else:
self.q_a_proj = get_linear(
weight=weights.get_weights(f"{prefix}.q_a_proj"),
bias=(
weights.get_tensor(f"{prefix}.q_a_proj.bias")
if config.attention_bias
else None
),
)
self.q_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.q_a_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.q_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.kv_a_proj_with_mqa = get_linear(
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
bias=(
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
if config.attention_bias
else None
),
)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_a_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.kv_b_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.kv_b_proj",
weights=weights,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def forward(
self,
hidden_states: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache: KVCache,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
if self.q_lora_rank is None:
query = self.q_proj(hidden_states)
else:
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
query = query.view(-1, self.num_heads, self.head_size)
_, query_pe = torch.split(
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
compressed_kv, key_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
)
key_nope, value = torch.split(
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
)
batch_size, heads, head_dim = query_pe.shape
query_pe = (
query_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
batch_size, heads, head_dim = key_pe.shape
key_pe = (
key_pe.view(batch_size, heads, head_dim // 2, 2)
.transpose(2, 3)
.reshape(batch_size, heads, head_dim)
)
self.rotary_emb(query_pe, key_pe, cos, sin)
query[..., self.qk_nope_head_dim :] = query_pe
key = torch.empty_like(query)
key[..., : self.qk_nope_head_dim] = key_nope
key[..., self.qk_nope_head_dim :] = key_pe
# We need to pad the heads because Flash Attention does not support
# qk and v with different head sizes.
query = torch.nn.functional.pad(
query, (0, self.head_pad_size - self.head_size), value=0
)
key = torch.nn.functional.pad(
key, (0, self.head_pad_size - self.head_size), value=0
)
value = torch.nn.functional.pad(
value, (0, self.head_pad_size - self.value_head_size), value=0
)
kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
attn_output = attention(
query=query,
key=key,
value=value,
kv_cache=kv_cache,
kv_scales=self.kv_scales,
seqlen=seqlen,
softmax_scale=self.softmax_scale,
)
# Decode
else:
attn_output = paged_attention(
query,
kv_cache,
self.kv_head_mapping,
self.softmax_scale,
seqlen,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
)
# Remove padding.
attn_output = attn_output[..., : self.value_head_size]
return self.o_proj(
attn_output.reshape(-1, self.num_heads * self.value_head_size)
)
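Editor's note: two details of the forward above are easy to miss. First, the view/transpose/reshape on query_pe and key_pe regroups the rotary slice from an interleaved pair layout into a half-split layout before rotary_emb is applied; second, query, key and value are all padded to head_pad_size so the attention kernels see a single head width, and the padding is sliced off again after attention. A tiny PyTorch check of the regrouping (toy tensor, independent of the model code):

import torch

# One token, one head, rotary dim 6, elements interleaved as (x0, y0, x1, y1, x2, y2).
pe = torch.tensor([[[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]]])
b, h, d = pe.shape
regrouped = pe.view(b, h, d // 2, 2).transpose(2, 3).reshape(b, h, d)
print(regrouped)  # tensor([[[0., 2., 4., 1., 3., 5.]]]) -> (x0, x1, x2, y0, y1, y2)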
class DeepseekV3MLP(nn.Module):
def __init__(self, prefix: str, config, weights, intermediate_size: int):
super().__init__()
self.hidden_act = config.hidden_act
if self.hidden_act != "silu":
# Bail out because MoE only supports silu.
raise NotImplementedError(
"Currently only `silu` is supported as an activation for Deepseek V2."
)
self.act = ACT2FN[self.hidden_act]
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.intermediate_size = intermediate_size // weights.process_group.size()
# TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
)
class DeepseekV3MoE(nn.Module):
def __init__(
self,
prefix,
config: DeepseekV3Config,
moe_layer_cls: Type[MoELayer],
weights,
):
super().__init__()
self.hidden_dim = config.hidden_size
self.moe_intermediate_size = (
config.moe_intermediate_size // weights.process_group.size()
)
self.routed_scaling_factor = config.routed_scaling_factor
# Gating
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = torch.zeros(
config.n_routed_experts, device=weights.device
)
else:
self.gate.e_score_correction_bias = None
self.moe_layer = moe_layer_cls(
prefix=f"{prefix}.experts",
n_experts=config.n_routed_experts,
n_expert_group=config.n_group,
renormalize=config.norm_topk_prob,
topk=config.num_experts_per_tok,
topk_group=config.topk_group,
weights=weights,
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,
)
assert isinstance(self.moe_layer, MoELayer)
if config.n_shared_experts is not None:
self.shared_experts = DeepseekV3MLP(
prefix=f"{prefix}.shared_experts",
config=config,
weights=weights,
intermediate_size=config.moe_intermediate_size
* config.n_shared_experts,
)
else:
self.shared_experts = None
self.process_group = weights.process_group
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.shared_experts is not None:
shared_output = self.shared_experts(x, reduce=False)
else:
shared_output = None
router_logits = self.gate(x)
out = self.moe_layer(x, gating_output=router_logits)
if shared_output is not None:
out = out + shared_output
# Reduce sum
if self.process_group.size() > 1:
torch.distributed.all_reduce(out, group=self.process_group)
return out.view(*x.shape)
class DeepseekV3Layer(nn.Module):
def __init__(self, prefix, layer_id, config, weights):
super().__init__()
prefix = f"{prefix}.layers.{layer_id}"
self.self_attn = DeepseekV3Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
and layer_id % config.moe_layer_freq == 0
):
moe_layer_cls = (
SparseMoELayer
if SparseMoELayer.is_supported(weights)
else DenseMoELayer
)
self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
else:
self.mlp = DeepseekV3MLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
intermediate_size=config.intermediate_size,
)
self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
def forward(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
cu_seqlen_prefill: torch.Tensor,
kv_cache,
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
):
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
attn_output = self.self_attn(
normed_hidden_states,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
# faster post attention rms norm
normed_attn_res_output, residual = self.post_attention_layernorm(
attn_output, residual
)
output = self.mlp(normed_attn_res_output)
return output, residual
class DeepseekV3Model(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList(
[
DeepseekV3Layer(
prefix,
layer_id,
config,
weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = FastRMSNorm.load(
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
)
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward
# Avoid indexing in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
residual = None
for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache[i],
slots,
seqlen,
hpu_attention_meta,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class FlashDeepseekV3ForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights: Weights):
super().__init__()
self.model = DeepseekV3Model(
"model" if not prefix else f"{prefix}.model", config, weights
)
self.lm_head = SpeculativeHead.load(
config,
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
weights=weights,
)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
hidden_states = self.model(
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits
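Editor's note: for a sense of the head-size bookkeeping in this new file, using the defaults from DeepseekV3Config above: query/key heads are qk_nope_head_dim + qk_rope_head_dim wide, value heads are v_head_dim wide, and everything is padded to the larger of the two before attention, then cropped back afterwards. A small arithmetic sketch:

qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128  # DeepseekV3Config defaults above
head_size = qk_nope_head_dim + qk_rope_head_dim       # 192: Q/K width per head
head_pad_size = max(head_size, v_head_dim)            # 192: shared width given to the attention kernel
value_padding = head_pad_size - v_head_dim            # 64 zero features appended to V, sliced off after attention
print(head_size, head_pad_size, value_padding)        # 192 192 64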

View File

@ -28,8 +28,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -40,7 +40,7 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear, TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear, TensorParallelAdapterRowLinear,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
@ -208,6 +208,7 @@ class FlashGemma2Attention(torch.nn.Module):
], ],
process_group=weights.process_group, process_group=weights.process_group,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
@ -234,11 +235,10 @@ class FlashGemma2Attention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states, adapter_data) qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split( query, kv = qkv.split(
@ -253,19 +253,24 @@ class FlashGemma2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], value=kv[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
causal=self.causal, softmax_scale=self.softmax_scale,
window_size_left=self.window_size, window_size_left=self.window_size,
softcap=self.softcap, softcap=self.softcap,
) )
@ -273,14 +278,13 @@ class FlashGemma2Attention(torch.nn.Module):
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s,
softcap=self.softcap, softcap=self.softcap,
kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj( return self.o_proj(
@ -390,11 +394,10 @@ class FlashGemma2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
hpu_attention_meta,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -405,11 +408,10 @@ class FlashGemma2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
hpu_attention_meta,
) )
# faster post attention rms norm # faster post attention rms norm
@ -458,19 +460,16 @@ class FlashGemma2Model(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, adapter_data: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@ -481,11 +480,10 @@ class FlashGemma2Model(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -529,11 +527,9 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -543,11 +539,10 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
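Editor's note: the Gemma 2 attention keeps forwarding softcap and window_size_left through both the prefill and decode paths. This diff does not show what the kernels do with softcap; conventionally it implements Gemma 2's attention-logit soft-capping, which we assume here. A one-line illustration of that transform, for reference only:

import torch

softcap = 50.0                                # illustrative value, not from the PR
scores = torch.randn(4, 4) * 100.0            # stand-in attention logits
capped = softcap * torch.tanh(scores / softcap)  # bounded to (-softcap, softcap)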

View File

@ -28,9 +28,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
PREFILL_IN_KV_CACHE, HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -39,6 +38,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
@ -187,6 +187,7 @@ class FlashGemmaAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@ -206,10 +207,9 @@ class FlashGemmaAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
query, kv = qkv.split( query, kv = qkv.split(
@ -224,31 +224,36 @@ class FlashGemmaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], value=kv[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
causal=self.causal, causal=self.causal,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -317,10 +322,9 @@ class FlashGemmaLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -331,10 +335,9 @@ class FlashGemmaLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
# faster post attention rms norm # faster post attention rms norm
@ -379,18 +382,16 @@ class FlashGemmaModel(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, adapter_data: Optional[torch.Tensor],
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@ -401,10 +402,9 @@ class FlashGemmaModel(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -446,11 +446,9 @@ class FlashGemmaForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -460,10 +458,10 @@ class FlashGemmaForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, adapter_data,
hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -24,12 +24,11 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -38,6 +37,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention.kv_cache import get_kv_scales
def load_qkv(config, prefix: str, weights, head_size, num_heads): def load_qkv(config, prefix: str, weights, head_size, num_heads):
@ -47,10 +47,6 @@ def load_qkv(config, prefix: str, weights, head_size, num_heads):
prefix, prefix,
weights, weights,
) )
elif config.quantize == "marlin":
raise RuntimeError(
"GPT-2 models with marlin quantization are not yet supported"
)
else: else:
return _load_qkv(config, prefix, weights, head_size, num_heads) return _load_qkv(config, prefix, weights, head_size, num_heads)
@ -195,6 +191,7 @@ class FlashGPT2Attention(torch.nn.Module):
head_size=self.head_size, head_size=self.head_size,
num_heads=self.num_heads, num_heads=self.num_heads,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row( self.o_proj = load_row(
config, config,
@ -212,10 +209,9 @@ class FlashGPT2Attention(torch.nn.Module):
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
query, key, value = self.query_key_value(hidden_states).split( query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1 self.head_size * self.num_heads, dim=1
@ -224,30 +220,35 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size) value = value.view(-1, self.num_heads, self.head_size)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key, key=key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value, value=value,
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -313,10 +314,9 @@ class FlashGPT2Layer(nn.Module):
residual, residual,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
@ -326,10 +326,9 @@ class FlashGPT2Layer(nn.Module):
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states = attn_output + residual hidden_states = attn_output + residual
@ -379,12 +378,9 @@ class FlashGPT2Model(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
@ -395,10 +391,9 @@ class FlashGPT2Model(torch.nn.Module):
residual, residual,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
@ -432,11 +427,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -448,12 +441,9 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta=hpu_attention_meta,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -24,12 +24,12 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -38,13 +38,16 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import ( from text_generation_server.layers.rotary import (
PositionRotaryEmbedding, PositionRotaryEmbedding,
) )
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
from habana_frameworks.torch.hpex.kernels import (
RotaryPosEmbeddingMode,
apply_rotary_pos_emb,
)
def load_attention(config, prefix: str, weights): def load_attention(config, prefix: str, weights):
@ -78,39 +81,25 @@ class GPTJRotary(PositionRotaryEmbedding):
cos: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, sin: torch.Tensor,
): ):
# Such controlflows may add some overhead. num_tokens = query.shape[0]
if SYSTEM == "cuda": head_size = query.shape[-1]
import rotary_emb rope_mode = RotaryPosEmbeddingMode.PAIRWISE
sin = torch.repeat_interleave(sin, 2, dim=-1)
cos = torch.repeat_interleave(cos, 2, dim=-1)
rotary_dim = cos.shape[-1]
query_shape = query.shape
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
q1 = query[..., ::2] key_shape = key.shape
q2 = query[..., 1::2] key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False) key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
k1 = key[..., ::2] key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
k2 = key[..., 1::2]
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
elif SYSTEM == "rocm":
from vllm._C import ops
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
head_size = query.shape[-1]
# Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, False)
elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), False
)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
)
class FlashGPTJAttention(torch.nn.Module): class FlashGPTJAttention(torch.nn.Module):
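Editor's note: the GPTJRotary forward now goes through habana_frameworks' apply_rotary_pos_emb in PAIRWISE mode, with cos/sin repeat-interleaved to per-element width and only the first rotary_dim features rotated. For intuition, here is a pure-PyTorch reference of the pairwise rotation itself, written under the assumption that PAIRWISE matches the classic GPT-J rotation of adjacent pairs; the HPU kernel call shown above is what the backend actually uses.

import torch

def pairwise_rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (..., rotary_dim) with adjacent pairs (x0, x1), (x2, x3), ...
    # cos/sin: (..., rotary_dim // 2), one angle per pair.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

angles = torch.rand(2, 4)
rotated = pairwise_rope_reference(torch.randn(2, 8), angles.cos(), angles.sin())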
@ -140,6 +129,7 @@ class FlashGPTJAttention(torch.nn.Module):
prefix=prefix, prefix=prefix,
weights=weights, weights=weights,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = load_row( self.o_proj = load_row(
config, config,
@ -166,10 +156,9 @@ class FlashGPTJAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
query, key, value = self.query_key_value(hidden_states).split( query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1 self.head_size * self.num_heads, dim=1
@ -186,30 +175,35 @@ class FlashGPTJAttention(torch.nn.Module):
else: else:
self.rotary_emb(query, key, cos, sin) self.rotary_emb(query, key, cos, sin)
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=key,
value=value,
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key, key=key,
kv_cache[1] if PREFILL_IN_KV_CACHE else value, value=value,
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -267,10 +261,9 @@ class FlashGPTJLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention # Self Attention
@ -280,10 +273,9 @@ class FlashGPTJLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
feed_forward_hidden_states = self.mlp(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states)
@ -327,19 +319,15 @@ class FlashGPTJModel(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) hidden_states = self.wte(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@ -350,10 +338,9 @@ class FlashGPTJModel(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
@ -381,11 +368,9 @@ class FlashGPTJForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -394,11 +379,9 @@ class FlashGPTJForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta=hpu_attention_meta,
prefill_cache_indices=prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@@ -27,14 +27,16 @@ import torch.distributed
from torch import nn from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention import (
KVCache,
get_kv_scales,
)
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@@ -57,15 +59,6 @@ from text_generation_server.utils.weights import (
) )
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
if SYSTEM != "ipex":
pass
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
def load_attention(config, prefix: str, weights, layer_id): def load_attention(config, prefix: str, weights, layer_id):
# Only defined in granite. # Only defined in granite.
@@ -157,7 +150,10 @@ class FlashLlamaAttention(torch.nn.Module):
device=weights.device, device=weights.device,
) )
self.softmax_scale = self.head_size**-0.5 # `config.attention_multiplier` is used in Granite
self.softmax_scale = getattr(
config, "attention_multiplier", self.head_size**-0.5
)
if self.num_heads % weights.process_group.size() != 0: if self.num_heads % weights.process_group.size() != 0:
raise ValueError( raise ValueError(
@@ -177,11 +173,13 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index) self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index self.index = index
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
weights=weights, weights=weights,
bias=False, bias=getattr(config, "attention_bias", False),
) )
self.o_proj = TensorParallelAdapterRowLinear.load( self.o_proj = TensorParallelAdapterRowLinear.load(
@@ -202,12 +200,11 @@ class FlashLlamaAttention(torch.nn.Module):
cos, cos,
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache: KVCache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
): ):
qkv = self.query_key_value(hidden_states, adapter_data) qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split( query, kv = qkv.split(
@@ -222,30 +219,35 @@ class FlashLlamaAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], value=kv[:, 1],
seqlen, kv_scales=self.kv_scales,
block_tables, kv_cache=kv_cache,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj( return self.o_proj(
@@ -363,31 +365,11 @@ class LlamaMLP(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
def forward(self, hidden_states, adapter_data): def forward(self, hidden_states, adapter_data):
if ( gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
SYSTEM == "rocm" gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
and self.hidden_act == "silu" return self.down_proj(
and hidden_states.dtype == torch.float16 self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
and hidden_states.shape[0] == 1 )
and not self.quantize
and self.hidden_size
!= 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
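With the ROCm-only fused LLMM_Silu path dropped, LlamaMLP.forward (and the matching MistralMLP change further down) reduces to the plain SwiGLU computation. A standalone toy check of that gate/up split, with made-up sizes:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size, tokens = 16, 64, 3
gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(tokens, hidden_size)
gate_up = gate_up_proj(x).view(-1, 2, intermediate_size)
# SwiGLU: silu(gate) * up, then the down projection
out = down_proj(F.silu(gate_up[:, 0]) * gate_up[:, 1])
print(out.shape)  # torch.Size([3, 16])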
class FlashLlamaLayer(nn.Module): class FlashLlamaLayer(nn.Module):
@@ -408,7 +390,7 @@ class FlashLlamaLayer(nn.Module):
if SparseMoELayer.is_supported(weights) if SparseMoELayer.is_supported(weights)
else DenseMoELayer else DenseMoELayer
) )
self.dense = Phi3MoE( self.mlp = Phi3MoE(
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
) )
# with moe the layernorms are not rmsnorms and they have bias # with moe the layernorms are not rmsnorms and they have bias
@@ -423,7 +405,7 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
) )
else: else:
self.dense = LlamaMLP( self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
) )
self.input_layernorm = FastRMSNorm.load( self.input_layernorm = FastRMSNorm.load(
@@ -437,6 +419,11 @@ class FlashLlamaLayer(nn.Module):
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
) )
# Used in Granite
# This could eventually be baked into the weights like we do for the embeddings/lm_head
# but this would mean modifying the lora code
self.residual_multiplier = getattr(config, "residual_multiplier", None)
def forward( def forward(
self, self,
hidden_states, hidden_states,
@@ -445,12 +432,11 @@ class FlashLlamaLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
cross_attention_states, cross_attention_states,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -461,19 +447,21 @@ class FlashLlamaLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
hpu_attention_meta=hpu_attention_meta,
) )
if self.residual_multiplier is not None:
attn_output *= self.residual_multiplier
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm( normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res attn_output, res
) )
mlp_output = self.dense(normed_attn_res_output, adapter_data) mlp_output = self.mlp(normed_attn_res_output, adapter_data)
if self.residual_multiplier is not None:
mlp_output *= self.residual_multiplier
return mlp_output, attn_res return mlp_output, attn_res
@@ -493,9 +481,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=0, index=0,
prefix=( prefix=f"{prefix}.layers.0",
"model.layers.0" if not prefix else f"{prefix}.model.layers.0"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@@ -504,18 +490,14 @@ class FlashLlamaModel(torch.nn.Module):
# Skip first and last layers # Skip first and last layers
for layer_id in range(1, config.num_hidden_layers - 1): for layer_id in range(1, config.num_hidden_layers - 1):
if layer_id in self.cross_attention_layers: if layer_id in self.cross_attention_layers:
from text_generation_server.models.custom_modeling.mllama import ( from text_generation_server.models.custom_modeling.flash_mllama import (
FlashLlamaCrossLayer, FlashLlamaCrossLayer,
) )
self.layers.append( self.layers.append(
FlashLlamaCrossLayer( FlashLlamaCrossLayer(
index=layer_id, index=layer_id,
prefix=( prefix=(f"{prefix}.layers.{layer_id}"),
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@@ -524,11 +506,7 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=layer_id, index=layer_id,
prefix=( prefix=(f"{prefix}.layers.{layer_id}"),
f"model.layers.{layer_id}"
if not prefix
else f"{prefix}.model.layers.{layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
@@ -539,18 +517,14 @@ class FlashLlamaModel(torch.nn.Module):
self.layers.append( self.layers.append(
FlashLlamaLayer( FlashLlamaLayer(
index=last_layer_id, index=last_layer_id,
prefix=( prefix=(f"{prefix}.layers.{last_layer_id}"),
f"model.layers.{last_layer_id}"
if not prefix
else f"{prefix}.model.layers.{last_layer_id}"
),
config=config, config=config,
weights=weights, weights=weights,
) )
) )
self.norm = FastRMSNorm.load( self.norm = FastRMSNorm.load(
prefix="model.norm" if not prefix else f"{prefix}.model.norm", prefix=f"{prefix}.norm",
weights=weights, weights=weights,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
) )
@@ -567,22 +541,17 @@ class FlashLlamaModel(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data, adapter_data,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
cross_attention_states=None, cross_attention_states=None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = inputs_embeds hidden_states = inputs_embeds
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@@ -593,12 +562,11 @@ class FlashLlamaModel(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
adapter_data, adapter_data,
cross_attention_states, cross_attention_states,
hpu_attention_meta=hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@@ -607,42 +575,60 @@ class FlashLlamaModel(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module): class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights): def __init__(self, prefix: str, config, weights, name=None):
if name is None:
name = "model"
super().__init__() super().__init__()
with no_fp8(weights): with no_fp8(weights):
self.embed_tokens = TensorParallelEmbedding( self.embed_tokens = TensorParallelEmbedding(
prefix=( prefix=(
"model.embed_tokens" f"{name}.embed_tokens"
if not prefix if not prefix
else f"{prefix}.model.embed_tokens" else f"{prefix}.{name}.embed_tokens"
), ),
weights=weights, weights=weights,
) )
self.model = FlashLlamaModel(prefix, config, weights) self.model = FlashLlamaModel(
prefix=name if not prefix else f"{prefix}.{name}",
config=config,
weights=weights,
)
if config.tie_word_embeddings: if config.tie_word_embeddings:
suffix = "model.embed_tokens" suffix = "model.embed_tokens"
else: else:
suffix = "lm_head" suffix = "lm_head"
# Used in Granite
embedding_multiplier = getattr(config, "embedding_multiplier", None)
if embedding_multiplier is not None:
self.embed_tokens.weight.data *= embedding_multiplier
prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
with no_fp8(weights): with no_fp8(weights):
self.lm_head = SpeculativeHead.load( self.lm_head = SpeculativeHead.load(
config, config,
prefix=suffix if not prefix else f"{prefix}.{suffix}", prefix,
weights=weights, weights,
) )
# Used in Granite
self.logits_scaling = getattr(config, "logits_scaling", None)
if self.logits_scaling is not None and self.lm_head.head is not None:
try:
# Scale the weights directly
self.lm_head.head.linear.weight.data /= self.logits_scaling
self.logits_scaled = True
except Exception:
self.logits_scaled = False
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
cross_attention_states=None, cross_attention_states=None,
@@ -653,16 +639,20 @@ class FlashLlamaForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
adapter_data=adapter_data, adapter_data=adapter_data,
cross_attention_states=cross_attention_states, cross_attention_states=cross_attention_states,
hpu_attention_meta=hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states) logits, speculative_logits = self.lm_head(hidden_states)
# Used in Granite
if self.logits_scaling is not None and not self.logits_scaled:
logits /= self.logits_scaling
if speculative_logits is not None:
speculative_logits /= self.logits_scaling
return logits, speculative_logits return logits, speculative_logits
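Taken together, the Granite additions in this file apply four constant multipliers: embedding_multiplier is baked into the embedding weights at load time, attention_multiplier replaces head_size**-0.5 as the softmax scale, residual_multiplier scales both residual branches of every layer, and logits_scaling divides the final (and speculative) logits, either by pre-dividing the lm_head weights or at runtime. A hedged sketch with made-up values:

# Hypothetical Granite-style values; the real ones come from the model config.
residual_multiplier = 0.22
logits_scaling = 16.0

def granite_layer(hidden_states, attn_branch, mlp_branch):
    # each branch output is scaled by residual_multiplier before the residual add
    hidden_states = hidden_states + residual_multiplier * attn_branch(hidden_states)
    return hidden_states + residual_multiplier * mlp_branch(hidden_states)

def granite_logits(raw_logits):
    # applied at runtime only when the lm_head weights could not be pre-divided
    return raw_logits / logits_scaling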

View File

@@ -0,0 +1,285 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Llava-NeXT model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
)
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
)
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (height, width).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
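A small usage check of the helper above, assuming it is in scope; the pinpoints are chosen so the best resolution is unambiguously (672, 672), i.e. a 2x2 grid of 336-pixel patches:

image_size = (672, 672)                                # (height, width)
grid_pinpoints = [[336, 672], [672, 336], [672, 672]]
patch_size = 336
print(get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size))  # (2, 2)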
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (`torch.Tensor`):
The image tensor, assumed to be of shape (num_channels, height, width).
original_size (`tuple`):
The original size of the image (height, width).
Returns:
`torch.Tensor`: The unpadded image tensor.
"""
original_height, original_width = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
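A quick shape check for unpad_image, assuming the function above is in scope: a 300x600 (height x width) image that was letterboxed into a 336x336 tensor gets its vertical padding stripped back off:

import torch

padded = torch.randn(3, 336, 336)              # (channels, height, width)
print(unpad_image(padded, (300, 600)).shape)   # torch.Size([3, 168, 336])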
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
class LlavaNextMultiModalProjector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.linear_1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
)
self.act = ACT2FN[config.projector_hidden_act]
self.linear_2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class FlashLlavaNextForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = config.quantize
vision_config = config.vision_config
# Instead of selecting hidden_states[-2] after running the full vision tower,
# build only the layers up to the vision feature layer and don't pool
if config.vision_feature_layer < 0:
vision_config.num_hidden_layers += config.vision_feature_layer + 1
else:
vision_config.num_hidden_layers = config.vision_feature_layer + 1
self.vision_tower = load_vision_model(
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
config=config.vision_config,
weights=weights,
)
self.multi_modal_projector = LlavaNextMultiModalProjector(
prefix="multi_modal_projector", config=config, weights=weights
)
self.image_newline = weights.get_tensor("image_newline")
self.vocab_size = config.text_config.vocab_size
self.config = config
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
self.text_model = load_text_model(
prefix="language_model" if not prefix else f"{prefix}.language_model",
config=config.text_config,
weights=weights,
)
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
mask = torch.where(input_ids == self.config.image_token_index)
# Let's pray we have enabled enough slots !
try:
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
except Exception as e:
raise RuntimeError(
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
)
return inputs_embeds
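The merge itself is just an index assignment over the image-token positions; a self-contained toy version of the pattern above (the token id and sizes are made up):

import torch

image_token_index = 32000
input_ids = torch.tensor([1, 32000, 32000, 5])
inputs_embeds = torch.zeros(4, 8)
image_features = torch.ones(2, 8)              # one row per image placeholder token

mask = torch.where(input_ids == image_token_index)
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
print(inputs_embeds.sum(dim=-1))               # tensor([0., 8., 8., 0.])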
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
# Unused for this model
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None and len(pixel_values) > 0:
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
# 1. Extract the input embeddings
# 2. Merge text and images
num_images, num_patches, channels, height, width = pixel_values.shape
pixel_values = pixel_values.view(
num_images * num_patches, channels, height, width
)
image_features = self.vision_tower(pixel_values)
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
# Already done within the clip model
selected_image_feature = image_features.last_hidden_state
if self.config.vision_feature_select_strategy == "default":
selected_image_feature = selected_image_feature[:, 1:]
elif self.config.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise RuntimeError(
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
)
image_features = self.multi_modal_projector(selected_image_feature)
# split up image_features for each of the individual images
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
# if we assume each image has 5 image features (base image + 4 patches)
split_sizes = [num_patches] * num_images
image_features = torch.split(image_features, split_sizes, dim=0)
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
height = width = (
self.config.vision_config.image_size
// self.config.vision_config.patch_size
)
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
if height * width != base_image_feature.shape[0]:
raise ValueError(
"The number of patches is not consistent with the image size."
)
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat(
(
image_feature,
self.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1
),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
image_feature = torch.cat(
(image_feature, self.image_newline[None]), dim=0
)
new_image_features.append(image_feature)
image_features = torch.stack(new_image_features, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_features
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits

View File

@@ -26,12 +26,12 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@@ -41,20 +41,12 @@ from text_generation_server.layers import (
TensorParallelMultiAdapterLinear, TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear, TensorParallelAdapterRowLinear,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
) )
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class MistralConfig(PretrainedConfig): class MistralConfig(PretrainedConfig):
model_type = "mistral" model_type = "mistral"
@@ -160,6 +152,7 @@ class MistralAttention(torch.nn.Module):
], ],
process_group=weights.process_group, process_group=weights.process_group,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
@@ -185,12 +178,10 @@ class MistralAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states, adapter_data) qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split( query, kv = qkv.split(
@@ -205,38 +196,36 @@ class MistralAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None: kv_cache.store(
kv_to_cache = kv[prefill_cache_indices] key=kv[:, 0],
else: value=kv[:, 1],
kv_to_cache = kv slots=slots,
kv_scales=self.kv_scales,
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], value=kv[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj( return self.o_proj(
@@ -300,29 +289,11 @@ class MistralMLP(nn.Module):
self.quantize = config.quantize self.quantize = config.quantize
def forward(self, hidden_states, adapter_data): def forward(self, hidden_states, adapter_data):
if ( gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
SYSTEM == "rocm" gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
and self.hidden_act == "silu" return self.down_proj(
and hidden_states.dtype == torch.float16 self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
and hidden_states.shape[0] == 1 )
and not self.quantize
):
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
)
return self.down_proj(out, adapter_data)
else:
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class MistralLayer(nn.Module): class MistralLayer(nn.Module):
@@ -355,12 +326,10 @@ class MistralLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -371,12 +340,10 @@ class MistralLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta,
) )
# faster post attention rms norm # faster post attention rms norm
@@ -423,20 +390,15 @@ class MistralModel(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
): ):
hidden_states = inputs_embeds hidden_states = inputs_embeds
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, true_max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@@ -447,12 +409,10 @@ class MistralModel(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s,
prefill_cache_indices,
adapter_data, adapter_data,
hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@@ -498,35 +458,21 @@ class FlashMistralForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = self.model( hidden_states = self.model(
inputs_embeds, inputs_embeds,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
true_max_s,
prefill_cache_indices,
adapter_data, adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@@ -37,9 +37,9 @@ from text_generation_server.layers.attention import (
Seqlen, Seqlen,
attention, attention,
paged_attention, paged_attention,
reshape_and_cache, HPUPagedAttentionMetadata,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastRMSNorm from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
@@ -215,6 +215,7 @@ class MixtralAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
@@ -234,11 +235,9 @@ class MixtralAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
prefill_cache_indices,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
query, kv = qkv.split( query, kv = qkv.split(
@@ -253,38 +252,36 @@ class MixtralAttention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None: kv_cache.store(
kv_to_cache = kv[prefill_cache_indices] key=kv[:, 0],
else: value=kv[:, 1],
kv_to_cache = kv slots=slots,
kv_scales=self.kv_scales,
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], value=kv[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@@ -378,11 +375,9 @@ class MixtralLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
prefill_cache_indices,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -393,11 +388,9 @@ class MixtralLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
prefill_cache_indices,
) )
# faster post attention rms norm # faster post attention rms norm
@@ -448,20 +441,15 @@ class MixtralModel(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, true_max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@@ -472,11 +460,9 @@ class MixtralModel(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
prefill_cache_indices,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@@ -507,34 +493,21 @@ class FlashMixtralForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
true_max_s,
prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@@ -0,0 +1,986 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Mllama model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
FastLinear,
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
FlashLlamaForCausalLM,
)
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(
batch_size, max_num_tiles * target_length, 1
)
attention_mask = (
attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
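A shape check for the helper above, assuming it is in scope: one image with 4 tiles and 5 real patches padded to a target length of 8 yields a (1, 1, 32, 32) additive mask:

import torch

mask = _prepare_aspect_ratio_attention_mask(
    aspect_ratio_mask=torch.tensor([[1, 1, 1, 0]]),  # last tile is padding
    num_patches=5,
    target_length=8,
    dtype=torch.float32,
)
print(mask.shape)  # torch.Size([1, 1, 32, 32])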
# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
min_dtype: float,
cache_position: torch.Tensor,
batch_size: int,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to place the 4D attention mask on.
min_dtype (`float`):
The minimum value representable with the dtype `dtype`.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length),
fill_value=min_dtype,
dtype=dtype,
device=device,
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(
target_length, device=device
) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = (
causal_mask.clone()
) # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = (
causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[
:, :, :, :mask_length
].masked_fill(padding_mask, min_dtype)
return causal_mask
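With attention_mask=None the helper above simply builds a standard additive causal mask, expanded to (batch_size, 1, sequence_length, target_length); a minimal call for illustration:

import torch

min_dtype = torch.finfo(torch.float32).min
causal = _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask=None,
    sequence_length=3,
    target_length=5,
    dtype=torch.float32,
    device="cpu",
    min_dtype=min_dtype,
    cache_position=torch.arange(3),
    batch_size=1,
)
print(causal.shape)               # torch.Size([1, 1, 3, 5])
print((causal[0, 0] == 0).int())  # 1 where attention is allowed (j <= i), 0 where the mask is -inf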
def _prepare_cross_attention_mask(
cross_attention_mask: torch.Tensor,
num_vision_tokens: int,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape so it can be used by attn module
batch_size, text_total_length, *_ = cross_attention_mask.shape
cross_attention_mask = cross_attention_mask.repeat_interleave(
num_vision_tokens, dim=3
)
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
cross_attention_mask = cross_attention_mask.unsqueeze(1)
# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
)
# apply full-row bias, which returns a 4D tensor of shape [B, H, S1, 1] where the value is 0 if a full row in the cross attn mask's
# last dimension contains negative infinity values, otherwise it's 1
negative_inf_value = torch.finfo(dtype).min
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value)
.any(dim=-1)
.type_as(cross_attention_mask)[..., None]
)
cross_attention_mask *= full_text_row_masked_out_mask
return cross_attention_mask, full_text_row_masked_out_mask
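A shape check for the cross-attention mask helper, assuming it is in scope: a batch of 1 with 3 text tokens and one image of 2 tiles with 4 vision tokens per tile, where the last text token attends to no vision tokens:

import torch

cross_mask = torch.ones(1, 3, 1, 2)    # (batch, text_len, num_images, num_tiles)
cross_mask[0, 2] = 0                   # last text token sees no vision tokens
attn_mask, full_row = _prepare_cross_attention_mask(
    cross_mask, num_vision_tokens=4, dtype=torch.float32
)
print(attn_mask.shape, full_row.shape)  # torch.Size([1, 1, 3, 8]) torch.Size([1, 1, 3, 1])
print(full_row[0, 0, :, 0])             # tensor([1., 1., 0.])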
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision
class MllamaVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class MllamaVisionSdpaAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.head_dim = config.hidden_size // config.attention_heads
self.num_heads = config.attention_heads // weights.process_group.size()
self.qkv_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_state)
query, key, value = qkv.split(
[
self.head_dim * self.num_heads,
self.head_dim * self.num_heads,
self.head_dim * self.num_heads,
],
dim=2,
)
batch_size, q_seq_len, _ = query.shape
_, kv_seq_len, _ = key.shape
query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_output = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
output = self.o_proj(attn_output)
return output
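The vision attention above is plain batched SDPA; the only bookkeeping is the (batch, seq, heads, head_dim) to (batch, heads, seq, head_dim) transposes around the call. A toy-sized reproduction of that reshape pattern:

import torch
import torch.nn.functional as F

batch, seq_len, num_heads, head_dim = 1, 6, 2, 8
q = torch.randn(batch, seq_len, num_heads, head_dim).transpose(1, 2)
k = torch.randn(batch, seq_len, num_heads, head_dim).transpose(1, 2)
v = torch.randn(batch, seq_len, num_heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v)  # attn_mask=None -> full attention
out = out.transpose(1, 2).reshape(batch, seq_len, num_heads * head_dim)
print(out.shape)  # torch.Size([1, 6, 16])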
class MllamaVisionEncoderLayer(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool):
super().__init__()
self.hidden_size = config.hidden_size
self.num_attention_heads = config.attention_heads
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.mlp = MllamaVisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=1e-05
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=1e-05
)
# there used to be an if else here, no code path
if is_gated:
self.gate_attn = nn.Parameter(
weights.get_tensor(f"{prefix}.gate_attn"), requires_grad=False
)
self.gate_ffn = nn.Parameter(
weights.get_tensor(f"{prefix}.gate_ffn"), requires_grad=False
)
def forward(
self,
hidden_state: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
# Self Attention
residual = hidden_state
hidden_state = self.input_layernorm(hidden_state)
hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
hidden_state = residual + gate_attn * hidden_state
# Feed forward
residual = hidden_state
hidden_state = self.post_attention_layernorm(hidden_state)
hidden_state = self.mlp(hidden_state)
gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
hidden_state = residual + gate_ffn * hidden_state
return hidden_state
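Both residual connections above follow the same gated pattern; in isolation it is simply the following, where a gate is only present for the gated (global) layers:

def gated_residual(x, branch_out, gate=None):
    # plain layers use a fixed gate of 1; gated layers scale the branch by tanh(gate)
    scale = 1 if gate is None else gate.tanh()
    return x + scale * branch_out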
class MllamaVisionEncoder(nn.Module):
def __init__(self, *, prefix, config, weights, is_gated: bool, num_layers: int):
super().__init__()
self.config = config
self.layers = [
MllamaVisionEncoderLayer(
prefix=f"{prefix}.layers.{i}",
config=config,
weights=weights,
is_gated=is_gated,
)
for i in range(num_layers)
]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
encoder_states = [hidden_states]
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
)
hidden_states = layer_outputs
encoder_states.append(hidden_states)
return hidden_states, encoder_states
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.embedding = TensorParallelEmbedding(
prefix=f"{prefix}.embedding", weights=weights
)
self.gate = nn.Parameter(
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
embeddings = self.embedding(aspect_ratio_ids)
embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
# Always gated.
embeddings = embeddings * self.gate.tanh()
hidden_state = hidden_state + embeddings
return hidden_state
class MllamaPrecomputedPositionEmbedding(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.max_num_tiles = config.max_num_tiles
self.max_aspect_ratio_id = config.max_aspect_ratio_id
self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
self.hidden_size = config.hidden_size
self.scale = config.hidden_size**-0.5
self.gate = nn.Parameter(
weights.get_tensor(f"{prefix}.gate"), requires_grad=False
)
# position embedding
embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
)
self.gated_position_embedding = (1 - self.gate.tanh()) * embedding
self.tile_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.tile_embedding", weights=weights
)
def forward(
self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
) -> torch.Tensor:
# position embeddings
hidden_state = hidden_state + self.gated_position_embedding.view(
1, 1, self.num_patches, self.hidden_size
)
# precomputed tile position embeddings
tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
batch_size = hidden_state.shape[0]
tile_position_embedding = tile_position_embedding.reshape(
batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
)
gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
hidden_state = hidden_state + gated_tile_position_embedding
return hidden_state
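Numerically, the position embedding above blends a fixed term and a per-aspect-ratio tile term with complementary tanh gates; a scalar sketch with made-up values:

import torch

gate = torch.tensor(0.5)
fixed_pos = torch.randn(10)   # gated_position_embedding, pre-scaled by (1 - tanh(gate)) at load time
tile_pos = torch.randn(10)    # tile_embedding(aspect_ratio_ids), scaled by tanh(gate) per forward
blended = (1 - gate.tanh()) * fixed_pos + gate.tanh() * tile_pos
print(blended.shape)  # torch.Size([10])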
class MllamaVisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.image_size = config.image_size
self.patch_size = config.patch_size
self.max_num_tiles = config.max_num_tiles
self.hidden_size = config.hidden_size
self.num_channels = config.num_channels
self.intermediate_layers_indices = config.intermediate_layers_indices
self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
self.scale = config.hidden_size**-0.5
self.dtype = weights.dtype
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.class_embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.class_embedding"), requires_grad=False
)
self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
prefix=f"{prefix}.gated_positional_embedding",
config=config,
weights=weights,
)
self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
prefix=f"{prefix}.pre_tile_positional_embedding",
config=config,
weights=weights,
)
self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
prefix=f"{prefix}.post_tile_positional_embedding",
config=config,
weights=weights,
)
## layer norms
self.layernorm_pre = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_pre",
weights=weights,
# torch default
eps=1e-05,
)
self.layernorm_post = nn.LayerNorm.load(
prefix=f"{prefix}.layernorm_post",
weights=weights,
# torch default
eps=1e-05,
)
## encoders
self.transformer = MllamaVisionEncoder(
prefix=f"{prefix}.transformer",
config=config,
weights=weights,
is_gated=False,
num_layers=config.num_hidden_layers,
)
self.global_transformer = MllamaVisionEncoder(
prefix=f"{prefix}.global_transformer",
config=config,
weights=weights,
is_gated=True,
num_layers=config.num_global_layers,
)
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(
self,
pixel_values: torch.Tensor,
aspect_ratio_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
(
batch_size,
num_concurrent_media,
num_tiles,
num_channels,
height,
width,
) = pixel_values.shape
pixel_values = pixel_values.reshape(
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
)
aspect_ratio_ids = aspect_ratio_ids.reshape(
batch_size * num_concurrent_media, -1
)
# patch embedding
patch_embeds = self.patch_embedding(pixel_values)
hidden_state = patch_embeds.flatten(2).transpose(1, 2)
# tile embeddings
_, num_patches, dim = hidden_state.shape
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, -1, dim
)
hidden_state = self.pre_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
# apply cls token
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media * num_tiles, num_patches, dim
)
hidden_state = self.apply_class_embedding(hidden_state)
num_patches += 1
# apply position embeddings
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media, num_tiles, num_patches, dim
)
hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
# apply encoder
hidden_state = self.layernorm_pre(hidden_state)
# Compute the number of tokens to pad
num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
# Compute padding tuple for pad function
padding = (
0,
0,
0,
num_padding_patches,
) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
# Pad the tensor
hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
slice_index = -num_padding_patches if num_padding_patches > 0 else None
if attention_mask is not None:
attention_mask = attention_mask.reshape(
batch_size * num_concurrent_media, -1
)
attention_mask = _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask=attention_mask,
num_patches=self.num_patches,
target_length=hidden_state.shape[2],
dtype=self.dtype,
)
hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
hidden_state, all_intermediate_hidden_states = self.transformer(
hidden_state,
attention_mask=attention_mask,
)
intermediate_hidden_states = [
hidden_state
for idx, hidden_state in enumerate(all_intermediate_hidden_states)
if idx in self.intermediate_layers_indices
]
intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
# apply global encoder
hidden_state = self.layernorm_post(hidden_state)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = self.post_tile_positional_embedding(
hidden_state, aspect_ratio_ids
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles * (num_patches + num_padding_patches),
dim,
)
hidden_state, _ = self.global_transformer(
hidden_state, attention_mask=attention_mask
)
hidden_state = hidden_state.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
dim,
)
hidden_state = hidden_state[:, :, :slice_index]
# adding intermediate layer outputs
hidden_state = hidden_state.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, dim
)
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size * num_concurrent_media,
num_tiles,
num_patches + num_padding_patches,
-1,
)
intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
intermediate_hidden_states = intermediate_hidden_states.reshape(
batch_size, num_concurrent_media, num_tiles, num_patches, -1
)
hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
return hidden_state
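
# The pad/unpad step in MllamaVisionModel.forward above rounds the patch dimension up
# to a multiple of 8 before the encoders and slices the padding off again afterwards
# via `slice_index`. A minimal, self-contained sketch of that arithmetic (toy sizes,
# not tied to any real Mllama configuration):
def _sketch_pad_patches_to_multiple_of_eight():
    import torch
    import torch.nn.functional as F

    num_patches = 101                                  # hypothetical patch count
    num_padding_patches = (8 - (num_patches % 8)) % 8  # 3 -> padded length 104
    x = torch.zeros(2, 4, num_patches, 16)             # (batch*media, tiles, patches, dim)
    x = F.pad(x, (0, 0, 0, num_padding_patches))       # pads dim -2 on the right only
    slice_index = -num_padding_patches if num_padding_patches > 0 else None
    assert x.shape[-2] == 104
    assert x[:, :, :slice_index].shape[-2] == num_patches
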
class MllamaTextCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, *, prefix, config, weights, layer_idx):
super().__init__()
self.config = config
self.num_heads = self.config.num_attention_heads
self.num_key_value_heads = self.config.num_key_value_heads
self.dropout = config.dropout
self.hidden_size = config.hidden_size
self.head_size = config.hidden_size // self.num_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.layer_idx = layer_idx
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
self.num_key_value_heads // weights.process_group.size()
)
self.q_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.q_proj",
weights=weights,
bias=False,
)
self.k_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.k_proj",
weights=weights,
bias=False,
)
self.v_proj = TensorParallelColumnLinear.load(
config,
prefix=f"{prefix}.v_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.q_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.q_norm", weights=weights, eps=config.rms_norm_eps
)
self.k_norm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.k_norm", weights=weights, eps=config.rms_norm_eps
)
self.softmax_scale = self.head_size**-0.5
def forward(
self,
hidden_states: torch.Tensor,
cross_attention_states: Optional[torch.Tensor] = None,
# past_key_value=None,
# attention_mask: Optional[torch.Tensor] = None,
# cache_position: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
# hidden_states = hidden_states.unsqueeze(0)
# bsz, q_len, _ = hidden_states.size()
(
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
indices,
) = cross_attention_states
bs = cu_seqlen_q.size(0) - 1
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
query_states = self.q_norm(query_states)
key_states = self.k_proj(cross_attention_states)
value_states = self.v_proj(cross_attention_states)
key_states = key_states.view(bs, -1, self.num_key_value_heads, self.head_size)
value_states = value_states.view(
bs, -1, self.num_key_value_heads, self.head_size
)
key_states = self.k_norm(key_states)
# key_states = key_states.repeat(1, self.num_key_value_groups, 1)
# value_states = value_states.repeat(1, self.num_key_value_groups, 1)
causal = False
# logger.info(
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
# )
# execute sdpa
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attn_output = fsdpa_op(
query_states,
key_states,
value_states,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=None,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return attn_output
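
# Shape contract of the cross-attention math above, sketched with PyTorch's stock
# scaled_dot_product_attention as a stand-in for the HPU FusedSDPA op (toy sizes;
# the KV heads are repeated explicitly here, which is what `num_key_value_groups`
# accounts for in the module):
def _sketch_cross_attention_shapes():
    import torch
    import torch.nn.functional as F

    bs, q_len, k_len, n_heads, n_kv_heads, head_dim = 2, 5, 7, 8, 4, 16
    q = torch.randn(bs, n_heads, q_len, head_dim)
    k = torch.randn(bs, n_kv_heads, k_len, head_dim)
    v = torch.randn(bs, n_kv_heads, k_len, head_dim)
    groups = n_heads // n_kv_heads
    k = k.repeat_interleave(groups, dim=1)             # grouped-query: expand KV heads
    v = v.repeat_interleave(groups, dim=1)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    assert out.shape == (bs, n_heads, q_len, head_dim)
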
# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText
class MllamaTextMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
shape = x.shape
gate_up_states = self.gate_up_proj(x)
gate_up_states = gate_up_states.view(*shape[:-1], 2, self.intermediate_size)
result = self.down_proj(
self.act_fn(gate_up_states[:, 0]) * gate_up_states[:, 1]
)
return result
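
# A quick check of the fused gate/up projection used above: a single matmul yields
# both halves, which are split with a view and combined as act(gate) * up before the
# down projection (toy sizes; SiLU stands in for config.hidden_act):
def _sketch_fused_gate_up():
    import torch
    import torch.nn.functional as F

    hidden, inter = 32, 64
    x = torch.randn(3, hidden)
    gate_up_weight = torch.randn(2 * inter, hidden)      # rows [0, inter) = gate, [inter, 2*inter) = up
    states = (x @ gate_up_weight.t()).view(3, 2, inter)  # mirrors gate_up_states.view(*shape[:-1], 2, inter)
    y = F.silu(states[:, 0]) * states[:, 1]
    assert y.shape == (3, inter)
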
class FlashLlamaCrossLayer(torch.nn.Module):
"""Cross-attention transformer block with tanh-gated attention and feedforward."""
def __init__(self, *, prefix, config, weights, index) -> None:
layer_idx = index
super().__init__()
self.cross_attn = MllamaTextCrossAttention(
prefix=f"{prefix}.cross_attn",
config=config,
weights=weights,
layer_idx=layer_idx,
)
self.input_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.cross_attn_attn_gate = torch.nn.Parameter(
weights.get_tensor(f"{prefix}.cross_attn_attn_gate"), requires_grad=False
)
self.mlp = MllamaTextMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.post_attention_layernorm = MllamaTextRMSNorm.load(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.cross_attn_mlp_gate = torch.nn.Parameter(
weights.get_tensor(f"{prefix}.cross_attn_mlp_gate"), requires_grad=False
)
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
residual,
cos,
sin,
cu_seqlen_prefill,
kv_cache,
slots,
seqlen,
adapter_data,
cross_attention_states, # [ IB, ...]
hpu_attention_meta,
) -> Tuple[torch.Tensor, torch.Tensor]:
if cross_attention_states is None:
return hidden_states, residual
if residual is not None:
hidden_states += residual
indices = cross_attention_states[-1]
out_hidden_states = hidden_states[:]
if len(indices) > 0:
assert max(indices) < hidden_states.shape[0]
hidden_states = hidden_states[indices]
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.cross_attn(
hidden_states=hidden_states,
# attention_mask=cross_attention_mask,
cross_attention_states=cross_attention_states,
)
hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
out_hidden_states[indices] = hidden_states
hidden_states = out_hidden_states
return hidden_states, None
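
# FlashLlamaCrossLayer.forward above only runs cross-attention for the rows that
# actually have image context (`indices`), then writes the tanh-gated result back
# into the full batch. A minimal sketch of that gather/update/scatter pattern, with
# an identity update standing in for attention + MLP:
def _sketch_gated_scatter_update():
    import torch

    hidden_states = torch.zeros(6, 4)
    indices = torch.tensor([1, 4])
    gate = torch.tensor(0.5)
    out = hidden_states[:]                               # full-batch output buffer
    selected = hidden_states[indices]                    # rows with image context
    updated = selected + gate.tanh() * torch.ones_like(selected)
    out[indices] = updated                               # scatter the gated update back
    assert torch.all(out[0] == 0) and torch.all(out[1] != 0)
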
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText
class MllamaTextRMSNorm(nn.Module):
def __init__(self, weight, eps):
super().__init__()
self.weight = weight
self.variance_epsilon = eps
@classmethod
def load(cls, *, prefix, weights, eps):
weight = nn.Parameter(
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
)
return cls(weight=weight, eps=eps)
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class FlashMllamaForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
config.text_config._attn_implementation = "sdpa"
self.hidden_size = config.text_config.hidden_size
self.vision_model = MllamaVisionModel(
prefix="vision_model", config=config.vision_config, weights=weights
)
self.multi_modal_projector = FastLinear.load(
prefix="multi_modal_projector", config=config, weights=weights, bias=True
)
self.text_model = FlashLlamaForCausalLM(
prefix="language_model", config=config.text_config, weights=weights
)
self.config = config
self.dtype = weights.dtype
self.device = weights.device
def vision_forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
if aspect_ratio_ids is None:
raise ValueError(
"`aspect_ratio_ids` must be provided if `pixel_values` is provided"
)
# logger.info(f"PIxel values {pixel_values.shape}")
batch_size = pixel_values.shape[0]
vision_states = self.vision_model(
pixel_values, aspect_ratio_ids, aspect_ratio_mask
)
cross_attention_states = self.multi_modal_projector(vision_states).reshape(
-1, vision_states.shape[-2], self.hidden_size
)
_, _, h = cross_attention_states.shape
cross_attention_states = cross_attention_states.view(batch_size, -1, h)
# logger.info(f"cross {cross_attention_states.shape}")
return cross_attention_states
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
        # XXX: Putting these as optional so that the warmup calls can go through.
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
if cross_attention_states is not None:
seqlen_q = len(image_indices)
n_images = cross_attention_states.shape[0]
seqlen_k = cross_attention_states.shape[1]
device = cross_attention_states.device
if cu_seqlen_prefill is not None:
offset = 0
cu_q = []
indices = []
for index in image_indices:
cu_q.append(offset)
length = seqlen.input_lengths[index].item()
assert index < seqlen.cu_seqlen_q.shape[0]
input_ids_offset = seqlen.cu_seqlen_q[index]
indices.extend(range(input_ids_offset, input_ids_offset + length))
offset += length
cu_q.append(offset)
cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
assert max(indices) < input_ids.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
else:
cu_seqlen_q = torch.arange(
seqlen_q + 1, device=device, dtype=torch.int32
)
seqlen_k = cross_attention_states.shape[1]
n_images = cross_attention_states.shape[0]
cu_seqlen_k = (
torch.arange(
n_images + 1,
device=device,
dtype=torch.int32,
)
* seqlen_k
)
indices = image_indices[:]
cross_attention_states = (
cross_attention_states,
cu_seqlen_q,
cu_seqlen_k,
indices,
)
outputs = self.text_model(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
lm_head_indices=lm_head_indices,
adapter_data=adapter_data,
cross_attention_states=cross_attention_states,
)
return outputs
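
# How the prefill branch of FlashMllamaForConditionalGeneration.forward builds the
# ragged cross-attention offsets: cu_seqlen_q accumulates the prompt lengths of the
# requests that carry an image, cu_seqlen_k advances by a fixed seqlen_k per image,
# and `indices` selects the text positions those images attend from. A toy
# walk-through with made-up lengths:
def _sketch_cross_attention_offsets():
    import torch

    input_lengths = [3, 5, 2]                            # per-request prompt lengths
    cu_seqlen_q_text = [0, 3, 8, 10]                     # prefix sums over the batch
    image_indices = [0, 2]                               # requests 0 and 2 each carry an image
    seqlen_k = 4                                         # vision tokens per image after projection

    offset, cu_q, indices = 0, [], []
    for index in image_indices:
        cu_q.append(offset)
        length = input_lengths[index]
        start = cu_seqlen_q_text[index]
        indices.extend(range(start, start + length))
        offset += length
    cu_q.append(offset)

    cu_seqlen_q = torch.tensor(cu_q, dtype=torch.int32)
    cu_seqlen_k = torch.arange(len(image_indices) + 1, dtype=torch.int32) * seqlen_k
    assert cu_seqlen_q.tolist() == [0, 3, 5]
    assert cu_seqlen_k.tolist() == [0, 4, 8]
    assert indices == [0, 1, 2, 8, 9]
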

View File

@ -29,8 +29,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -39,7 +39,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
@ -132,6 +132,7 @@ class FlashNeoxAttention(torch.nn.Module):
head_size=self.head_size, head_size=self.head_size,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row( self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=True config, prefix=f"{prefix}.dense", weights=weights, bias=True
) )
@ -146,10 +147,9 @@ class FlashNeoxAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, 3, self.num_heads, self.head_size) qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
@ -165,30 +165,35 @@ class FlashNeoxAttention(torch.nn.Module):
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=qkv[:, 1],
value=qkv[:, 2],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
qkv[:, 0], query=qkv[:, 0],
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1], key=qkv[:, 1],
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2], value=qkv[:, 2],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
qkv[:, 0], qkv[:, 0],
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -255,10 +260,9 @@ class FlashNeoXLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
if self.use_parallel_residual: if self.use_parallel_residual:
ln1_hidden_states, _ = self.input_layernorm(hidden_states) ln1_hidden_states, _ = self.input_layernorm(hidden_states)
@ -269,10 +273,9 @@ class FlashNeoXLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
@ -293,10 +296,9 @@ class FlashNeoXLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
@ -347,18 +349,15 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_in(input_ids) hidden_states = self.embed_in(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin( cos, sin = self.layers[0].attention.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@ -369,10 +368,9 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.final_layer_norm(hidden_states, residual) hidden_states, _ = self.final_layer_norm(hidden_states, residual)
@ -401,11 +399,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -414,10 +410,9 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
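
The hunks above replace reshape_and_cache with a kv_cache.store(...) call and route prefill through attention(...) while decode goes through paged_attention(..., hpu_attention_meta=...). A minimal stand-in sketch of the store pattern that runs outside the server; ToyKVCache below is hypothetical, the real cache object lives in text_generation_server.layers.attention:

import torch

class ToyKVCache:
    """Hypothetical stand-in for the server's KV cache object."""

    def __init__(self, num_slots, num_kv_heads, head_size):
        self.key = torch.zeros(num_slots, num_kv_heads, head_size)
        self.value = torch.zeros(num_slots, num_kv_heads, head_size)

    def store(self, *, key, value, slots, kv_scales=None):
        # one flat slot per token; kv_scales is ignored in this unquantized toy
        self.key[slots] = key
        self.value[slots] = value

cache = ToyKVCache(num_slots=16, num_kv_heads=2, head_size=8)
kv = torch.randn(4, 2, 2, 8)  # (tokens, 2, kv_heads, head_size), like kv[:, 0] / kv[:, 1]
cache.store(key=kv[:, 0], value=kv[:, 1], slots=torch.tensor([0, 1, 2, 3]), kv_scales=None)
assert torch.equal(cache.key[0], kv[0, 0])
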

View File

@ -19,7 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from text_generation_server.models.custom_modeling.vlm import ( from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
load_vision_model, load_vision_model,
@ -69,22 +69,20 @@ class PaliGemmaForConditionalGeneration(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
# Unused here # Unused here
pixel_attention_mask: Optional[torch.BoolTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None,
image_sizes: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.text_model.embed_tokens(input_ids) inputs_embeds = self.text_model.embed_tokens(input_ids)
# TODO This is odd but apparently pali gemma position ids start at 1. # TODO This is odd but apparently pali gemma position ids start at 1.
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
max_s += 1
position_ids += 1 position_ids += 1
if pixel_values is not None: if pixel_values is not None:
@ -106,10 +104,10 @@ class PaliGemmaForConditionalGeneration(nn.Module):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables,
slots=slots, slots=slots,
seqlen=seqlen, seqlen=seqlen,
max_s=max_s, hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:

View File

@ -9,8 +9,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -19,7 +19,7 @@ from text_generation_server.layers import (
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
) )
@ -90,7 +90,7 @@ def _load_gqa(config, prefix: str, weights):
dim=0, dim=0,
) )
if config.quantize not in ["gptq", "awq", "marlin"]: if config.quantize not in ["gptq", "awq"]:
weight = weight.to(dtype=weights.dtype).to(device=weights.device) weight = weight.to(dtype=weights.dtype).to(device=weights.device)
head_size = config.hidden_size // config.num_attention_heads head_size = config.hidden_size // config.num_attention_heads
@ -139,6 +139,7 @@ class FlashPhiAttention(torch.nn.Module):
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
# in llama the dense layer is called "o_proj" and has bias=False # in llama the dense layer is called "o_proj" and has bias=False
self.dense = TensorParallelRowLinear.load( self.dense = TensorParallelRowLinear.load(
@ -159,10 +160,9 @@ class FlashPhiAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
# Compute query, key, value and split # Compute query, key, value and split
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -188,29 +188,34 @@ class FlashPhiAttention(torch.nn.Module):
) )
# Reshape key and value and cache # Reshape key and value and cache
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], value=kv[:, 1],
seqlen, kv_scales=self.kv_scales,
block_tables, kv_cache=kv_cache,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -274,10 +279,9 @@ class FlashPhiLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
hidden_states, res = self.input_layernorm(hidden_states, residual) hidden_states, res = self.input_layernorm(hidden_states, residual)
# Self Attention # Self Attention
@ -287,10 +291,9 @@ class FlashPhiLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states = self.resid_dropout(attn_output).add( hidden_states = self.resid_dropout(attn_output).add(
@ -339,18 +342,15 @@ class FlashPhiModel(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@ -361,10 +361,9 @@ class FlashPhiModel(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -394,11 +393,9 @@ class FlashPhiForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -407,10 +404,9 @@ class FlashPhiForCausalLM(torch.nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
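
Every attention module touched above keeps the same two-branch dispatch: when cu_seqlen_prefill is set the whole prompt is attended in one shot via attention(...), otherwise paged_attention(...) answers a single decode step from the cache, now keyed by hpu_attention_meta instead of block_tables/max_s. A stand-in sketch of that dispatch; the callables below are placeholders, not the real TGI kernels:

import torch

def toy_attention_dispatch(query, cu_seqlen_prefill, prefill_fn, decode_fn):
    # Prefill: the whole prompt is present, so the dense/sdpa path is used.
    if cu_seqlen_prefill is not None:
        return prefill_fn(query)
    # Decode: one new token per sequence, answered from the paged KV cache.
    return decode_fn(query)

q = torch.randn(4, 8)
out = toy_attention_dispatch(
    q,
    cu_seqlen_prefill=torch.tensor([0, 4]),  # cumulative prompt lengths for a 1-request batch
    prefill_fn=lambda t: t,
    decode_fn=lambda t: torch.zeros_like(t),
)
assert torch.equal(out, q)
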

View File

@ -8,8 +8,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -17,7 +17,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
SpeculativeHead, SpeculativeHead,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastRMSNorm, FastRMSNorm,
@ -86,6 +86,8 @@ class Qwen2Attention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( self.o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
@ -104,11 +106,9 @@ class Qwen2Attention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
prefill_cache_indices,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
query, kv = qkv.split( query, kv = qkv.split(
@ -123,38 +123,36 @@ class Qwen2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None: kv_cache.store(
kv_to_cache = kv[prefill_cache_indices] key=kv[:, 0],
else: value=kv[:, 1],
kv_to_cache = kv slots=slots,
kv_scales=self.kv_scales,
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], value=kv[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -223,13 +221,11 @@ class Qwen2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
prefill_cache_indices,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, residual = self.input_layernorm(hidden_states)
# Self Attention # Self Attention
attn_output = self.self_attn( attn_output = self.self_attn(
@ -238,21 +234,17 @@ class Qwen2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
prefill_cache_indices,
) )
hidden_states = attn_output + residual
# faster post attention rms norm # faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(hidden_states)
attn_output, res mlp_output = self.mlp(hidden_states)
) hidden_states = mlp_output + residual
return hidden_states
mlp_output = self.mlp(normed_attn_res_output)
return mlp_output, attn_res
class Qwen2Model(torch.nn.Module): class Qwen2Model(torch.nn.Module):
@ -264,9 +256,6 @@ class Qwen2Model(torch.nn.Module):
process_group = weights.process_group process_group = weights.process_group
self.tp_rank = process_group.rank() self.tp_rank = process_group.rank()
self.tp_world_size = process_group.size() self.tp_world_size = process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens", weights=weights
)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
Qwen2Layer( Qwen2Layer(
@ -290,42 +279,35 @@ class Qwen2Model(torch.nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = inputs_embeds
# Get rotary cos and sin for this forward
# Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
position_ids, true_max_s, hidden_states.dtype position_ids,
) )
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
hidden_states, residual = layer( hidden_states = layer(
hidden_states, hidden_states,
residual, residual,
cos, cos,
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
prefill_cache_indices,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states)
return hidden_states return hidden_states
@ -346,6 +328,12 @@ class Qwen2ForCausalLM(torch.nn.Module):
prefix=f"{prefix}.{suffix}" if prefix else suffix, prefix=f"{prefix}.{suffix}" if prefix else suffix,
weights=weights, weights=weights,
) )
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
weights=weights,
)
self.max_past = config.sliding_window self.max_past = config.sliding_window
self.max_past_tensor = ( self.max_past_tensor = (
torch.tensor(config.sliding_window, device=weights.device) torch.tensor(config.sliding_window, device=weights.device)
@ -359,34 +347,23 @@ class Qwen2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None: inputs_embeds = self.embed_tokens(input_ids)
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, inputs_embeds,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
true_max_s,
prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]
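
The Qwen2Layer rewrite above moves the residual bookkeeping inside the layer: each norm returns (normed, residual), the sub-block output is added to that residual, and the layer returns a single tensor instead of a (hidden_states, residual) pair. A toy version of that flow, with a toy norm and stand-in sub-blocks:

import torch

def toy_norm(x):
    # returns (normed activations, residual to add after the sub-block)
    return x / (x.norm(dim=-1, keepdim=True) + 1e-6), x

def toy_qwen2_layer(x, attn, mlp):
    h, residual = toy_norm(x)
    h = attn(h) + residual
    h, residual = toy_norm(h)
    return mlp(h) + residual

x = torch.randn(3, 4)
out = toy_qwen2_layer(x, attn=lambda t: 0.1 * t, mlp=lambda t: 0.2 * t)
assert out.shape == x.shape
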

View File

@ -12,14 +12,14 @@ from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
get_linear, get_linear,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import FastLayerNorm from text_generation_server.layers.layernorm import FastLayerNorm
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
attention, attention,
paged_attention, paged_attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
@ -79,6 +79,7 @@ class RWConfig(PretrainedConfig):
self.alibi = False self.alibi = False
self.rotary = True self.rotary = True
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.max_position_embeddings = 2048
self.vocab_size = vocab_size self.vocab_size = vocab_size
# Backward compatibility with n_embed kwarg # Backward compatibility with n_embed kwarg
@ -160,6 +161,7 @@ class FlashRWAttention(torch.nn.Module):
weights=weights, weights=weights,
bias=config.bias, bias=config.bias,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row( self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
) )
@ -180,10 +182,9 @@ class FlashRWAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
@ -200,30 +201,35 @@ class FlashRWAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots) kv_cache.store(
key=kv[:, 0],
value=kv[:, 1],
slots=slots,
kv_scales=self.kv_scales,
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1], value=kv[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.dense(attn_output.view(-1, self.num_heads * self.head_size)) return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
@ -278,6 +284,7 @@ class FlashRWLargeAttention(torch.nn.Module):
weights=weights, weights=weights,
bias=config.bias, bias=config.bias,
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.dense = load_row( self.dense = load_row(
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
) )
@ -293,10 +300,9 @@ class FlashRWLargeAttention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states)
qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size) qkv = qkv.view(-1, self.num_groups, self.num_heads + 2, self.head_size)
@ -312,36 +318,35 @@ class FlashRWLargeAttention(torch.nn.Module):
# Inplace rotary # Inplace rotary
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
reshape_and_cache( kv_cache.store(
kv[:, :, 0].contiguous(), key=kv[:, :, 0].contiguous(),
kv[:, :, 1].contiguous(), value=kv[:, :, 1].contiguous(),
kv_cache[0], slots=slots,
kv_cache[1], kv_scales=self.kv_scales,
slots,
) )
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(), key=kv[:, :, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(), value=kv[:, :, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.dense( return self.dense(
@ -424,10 +429,9 @@ class FlashRWLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
if self.parallel_attn: if self.parallel_attn:
ln_hidden_states, residual = self.input_layernorm(hidden_states, residual) ln_hidden_states, residual = self.input_layernorm(hidden_states, residual)
@ -438,10 +442,9 @@ class FlashRWLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
mlp_output = self.mlp(ln_hidden_states) mlp_output = self.mlp(ln_hidden_states)
@ -460,10 +463,9 @@ class FlashRWLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
if self.post_attention_layernorm is not None: if self.post_attention_layernorm is not None:
@ -547,10 +549,9 @@ class FlashRWLargeLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
# Layer norm. # Layer norm.
ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual) ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual)
@ -562,10 +563,9 @@ class FlashRWLargeLayer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
# MLP. # MLP.
@ -623,18 +623,15 @@ class FlashRWModel(FlashRWPreTrainedModel):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin( cos, sin = self.h[0].self_attention.rotary_emb.get_cos_sin(position_ids)
position_ids, max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.h): for i, layer in enumerate(self.h):
@ -645,10 +642,9 @@ class FlashRWModel(FlashRWPreTrainedModel):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
@ -675,11 +671,9 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -688,10 +682,9 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -8,8 +8,8 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelRowLinear, TensorParallelRowLinear,
@ -18,7 +18,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
get_linear, get_linear,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.gptq import GPTQWeightsLoader from text_generation_server.layers.gptq import GPTQWeightsLoader
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
@ -32,10 +32,6 @@ def load_multi_mqa(
return _load_multi_mqa_gptq( return _load_multi_mqa_gptq(
config, prefix, weights, bias, head_size, num_heads, hidden_size config, prefix, weights, bias, head_size, num_heads, hidden_size
) )
elif config.quantize == "marlin":
raise RuntimeError(
"santacoder models with marlin quantization are not yet supported"
)
else: else:
return _load_multi_mqa( return _load_multi_mqa(
config, prefix, weights, bias, head_size, num_heads, hidden_size config, prefix, weights, bias, head_size, num_heads, hidden_size
@ -259,6 +255,7 @@ class FlashMQAttention(torch.nn.Module):
self.c_proj = load_row( self.c_proj = load_row(
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
) )
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.kv_head_mapping = torch.zeros( self.kv_head_mapping = torch.zeros(
self.num_heads, dtype=torch.int32, device=weights.device self.num_heads, dtype=torch.int32, device=weights.device
) )
@ -268,10 +265,9 @@ class FlashMQAttention(torch.nn.Module):
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
qkv = self.c_attn(hidden_states) qkv = self.c_attn(hidden_states)
@ -284,32 +280,35 @@ class FlashMQAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
key_value = key_value.view(-1, 2, 1, self.head_size) key_value = key_value.view(-1, 2, 1, self.head_size)
reshape_and_cache( kv_cache.store(
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots key=key_value[:, 0],
value=key_value[:, 1],
slots=slots,
kv_scales=self.kv_scales,
) )
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0], key=key_value[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1], value=key_value[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -371,20 +370,18 @@ class Block(nn.Module):
residual, residual,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
): ):
hidden_states, residual = self.ln_1(hidden_states, residual) hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.self_attn(
hidden_states, hidden_states,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, residual = self.ln_2(hidden_states, residual) hidden_states, residual = self.ln_2(hidden_states, residual)
@ -435,10 +432,9 @@ class FlashSantacoderModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) + self.wpe(position_ids) hidden_states = self.wte(input_ids) + self.wpe(position_ids)
@ -452,10 +448,9 @@ class FlashSantacoderModel(nn.Module):
residual, residual,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
hidden_states, _ = self.ln_f(hidden_states, residual) hidden_states, _ = self.ln_f(hidden_states, residual)
@ -484,11 +479,9 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
@ -497,10 +490,9 @@ class FlashSantacoderForCausalLM(nn.Module):
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, hpu_attention_meta,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -29,17 +29,19 @@ from typing import Optional, List, Tuple
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (
paged_attention, paged_attention,
attention, attention,
reshape_and_cache,
Seqlen, Seqlen,
HPUPagedAttentionMetadata,
) )
from text_generation_server.layers import ( from text_generation_server.layers import (
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
TensorParallelRowLinear, TensorParallelRowLinear,
TensorParallelColumnLinear, TensorParallelColumnLinear,
TensorParallelEmbedding, TensorParallelEmbedding,
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
) )
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
FastLayerNorm, FastLayerNorm,
FastRMSNorm, FastRMSNorm,
@ -110,17 +112,31 @@ class Starcoder2Config(PretrainedConfig):
) )
def load_attention(config, prefix, weights): def load_attention(config, prefix, weights, layer_id):
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
head_size = config.hidden_size // config.num_attention_heads
sizes = [
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
]
if config.num_attention_heads != config.num_key_value_heads: if config.num_attention_heads != config.num_key_value_heads:
return _load_gqa(config, prefix, weights) base_layer = _load_gqa(config, prefix, weights)
else: else:
return TensorParallelColumnLinear.load_multi( base_layer = TensorParallelColumnLinear.load_multi(
config, config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=prefixes,
dim=0, dim=0,
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
return TensorParallelMultiAdapterLinear.load(
base_layer=base_layer,
layer_id=layer_id,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
def _load_gqa(config, prefix: str, weights): def _load_gqa(config, prefix: str, weights):
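
The new load_attention above fuses q/k/v into one TensorParallelMultiAdapterLinear and has to tell it how the fused output splits back into the three projections. A toy computation of those per-projection sizes, with hypothetical head counts rather than a real Starcoder2 config:

num_attention_heads, num_key_value_heads, hidden_size = 32, 8, 4096
head_size = hidden_size // num_attention_heads  # 128
sizes = [
    head_size * num_attention_heads,            # q_proj output rows: 4096
    head_size * num_key_value_heads,            # k_proj output rows: 1024
    head_size * num_key_value_heads,            # v_proj output rows: 1024
]
assert sum(sizes) == head_size * (num_attention_heads + 2 * num_key_value_heads)
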
@ -158,6 +174,7 @@ def _load_gqa(config, prefix: str, weights):
class Starcoder2Attention(torch.nn.Module): class Starcoder2Attention(torch.nn.Module):
def __init__( def __init__(
self, self,
index: int,
prefix: str, prefix: str,
config, config,
weights, weights,
@ -189,14 +206,23 @@ class Starcoder2Attention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size() config.num_key_value_heads // weights.process_group.size()
) )
self.query_key_value = load_attention(config, prefix, weights) self.query_key_value = load_attention(config, prefix, weights, index)
self.kv_scales = get_kv_scales(weights, f"{prefix}")
self.o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
weights=weights, weights=weights,
bias=config.use_bias, bias=getattr(config, "use_bias", False),
) )
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange( self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -209,13 +235,12 @@ class Starcoder2Attention(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, adapter_data,
prefill_cache_indices, hpu_attention_meta,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split( query, kv = qkv.split(
[ [
self.head_size * self.num_heads, self.head_size * self.num_heads,
@ -228,45 +253,45 @@ class Starcoder2Attention(torch.nn.Module):
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin) self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
if prefill_cache_indices is not None: kv_cache.store(
kv_to_cache = kv[prefill_cache_indices] key=kv[:, 0],
else: value=kv[:, 1],
kv_to_cache = kv slots=slots,
kv_scales=self.kv_scales,
reshape_and_cache(
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
) )
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # sdpa
attn_output = attention( attn_output = attention(
query, query=query,
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0], key=kv[:, 0],
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1], value=kv[:, 1],
seqlen, kv_cache=kv_cache,
block_tables, kv_scales=self.kv_scales,
self.softmax_scale, seqlen=seqlen,
softmax_scale=self.softmax_scale,
window_size_left=self.max_past, window_size_left=self.max_past,
) )
# Decode # Decode
else: else:
attn_output = paged_attention( attn_output = paged_attention(
query, query,
kv_cache[0], kv_cache,
kv_cache[1],
self.kv_head_mapping, self.kv_head_mapping,
self.softmax_scale, self.softmax_scale,
block_tables,
seqlen, seqlen,
max_s, kv_scales=self.kv_scales,
hpu_attention_meta=hpu_attention_meta,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Starcoder2MLP(nn.Module): class Starcoder2MLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights, index):
super().__init__() super().__init__()
act = config.hidden_act act = config.hidden_act
self.act = ( self.act = (
@ -280,27 +305,42 @@ class Starcoder2MLP(nn.Module):
) )
) )
# Fuse gate and up proj # Fuse gate and up proj
self.c_fc = TensorParallelColumnLinear.load( c_fc = TensorParallelColumnLinear.load(
config, config,
prefix=f"{prefix}.c_fc", prefix=f"{prefix}.c_fc",
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
self.c_proj = TensorParallelRowLinear.load( c_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.c_proj", prefix=f"{prefix}.c_proj",
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
def forward(self, hidden_states): self.c_fc = TensorParallelMultiAdapterLinear.load(
hidden_states = self.c_fc(hidden_states) c_fc,
layer_id=index,
layer_names=[f"{prefix}.c_fc"],
sizes=[config.intermediate_size, config.intermediate_size],
process_group=weights.process_group,
)
self.c_proj = TensorParallelAdapterRowLinear.load(
c_proj,
index,
"c_proj",
process_group=weights.process_group,
)
def forward(self, hidden_states, adapter_data):
hidden_states = self.c_fc(hidden_states, adapter_data)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
return self.c_proj(hidden_states) return self.c_proj(hidden_states, adapter_data)
class Starcoder2GatedMLP(nn.Module): class Starcoder2GatedMLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, index, prefix, config, weights):
super().__init__() super().__init__()
act = config.hidden_act act = config.hidden_act
self.act = ( self.act = (
@ -314,27 +354,47 @@ class Starcoder2GatedMLP(nn.Module):
) )
) )
# Fuse gate and up proj # Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi( prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
sizes = [
config.intermediate_size,
config.intermediate_size,
]
gate_up_proj = TensorParallelColumnLinear.load_multi(
config, config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], prefixes=prefixes,
weights=weights, weights=weights,
dim=0, dim=0,
bias=config.use_bias, bias=config.use_bias,
) )
self.down_proj = TensorParallelRowLinear.load( self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
layer_names=prefixes,
sizes=sizes,
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
weights=weights, weights=weights,
bias=config.use_bias, bias=config.use_bias,
) )
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = ( self.intermediate_size = (
config.intermediate_size // weights.process_group.size() config.intermediate_size // weights.process_group.size()
) )
def forward(self, hidden_states): def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
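A standalone sketch of the gated-MLP math used in `Starcoder2GatedMLP.forward` above, with hypothetical sizes and SiLU standing in for `config.hidden_act`: the fused column-parallel layer emits the gate and up projections concatenated along the output dimension, and a single `view` splits them before the gated product.

```python
import torch
import torch.nn.functional as F

intermediate_size = 8
hidden = torch.randn(4, 16)
# [gate_proj; up_proj] weights stacked along the output dimension (dim 0),
# mirroring TensorParallelColumnLinear.load_multi(..., dim=0).
gate_up_weight = torch.randn(2 * intermediate_size, 16)

gate_up = hidden @ gate_up_weight.T               # (4, 2 * intermediate_size)
gate_up = gate_up.view(-1, 2, intermediate_size)  # [:, 0] = gate, [:, 1] = up
out = F.silu(gate_up[:, 0]) * gate_up[:, 1]       # (4, intermediate_size)
print(out.shape)                                  # torch.Size([4, 8])
```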
STARCODER2_NORMALIZATION_CLASSES = { STARCODER2_NORMALIZATION_CLASSES = {
@ -353,11 +413,11 @@ class Starcoder2Layer(nn.Module):
super().__init__() super().__init__()
prefix = f"model.layers.{layer_id}" prefix = f"model.layers.{layer_id}"
self.self_attn = Starcoder2Attention( self.self_attn = Starcoder2Attention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
) )
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type]( self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
prefix=f"{prefix}.mlp", config=config, weights=weights prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
) )
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load( self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
@ -379,11 +439,10 @@ class Starcoder2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, adapter_data,
prefill_cache_indices, hpu_attention_meta,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -394,11 +453,10 @@ class Starcoder2Layer(nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, adapter_data,
prefill_cache_indices, hpu_attention_meta,
) )
# faster post attention rms norm # faster post attention rms norm
@ -406,7 +464,7 @@ class Starcoder2Layer(nn.Module):
attn_output, res attn_output, res
) )
mlp_output = self.mlp(normed_attn_res_output) mlp_output = self.mlp(normed_attn_res_output, adapter_data)
return mlp_output, attn_res return mlp_output, attn_res
@ -447,20 +505,16 @@ class Starcoder2Model(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, adapter_data,
true_max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
# Avoid to index in each layer # Avoid to index in each layer
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin( cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
position_ids, true_max_s, hidden_states.dtype
)
residual = None residual = None
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
@ -471,11 +525,10 @@ class Starcoder2Model(torch.nn.Module):
sin, sin,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache[i], kv_cache[i],
block_tables,
slots, slots,
seqlen, seqlen,
max_s, adapter_data,
prefill_cache_indices, hpu_attention_meta,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -519,34 +572,22 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
true_max_s = max_s
if prefill_cache_indices is not None:
# Slots also need to be sliced as it has the same size as the whole kv tensor
slots = slots[prefill_cache_indices]
elif self.max_past is not None:
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
# kernel requires the true values
seqlen = seqlen.clamp(max=self.max_past_tensor)
hidden_states = self.model( hidden_states = self.model(
input_ids, input_ids,
position_ids, position_ids,
cu_seqlen_prefill, cu_seqlen_prefill,
kv_cache, kv_cache,
block_tables,
slots, slots,
seqlen, seqlen,
max_s, adapter_data,
true_max_s, hpu_attention_meta,
prefill_cache_indices,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -25,7 +25,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import ( from text_generation_server.models.custom_modeling.vlm import (
load_text_model, load_text_model,
) )
from text_generation_server.layers.attention import Seqlen from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import ( from text_generation_server.layers import (
@ -728,7 +728,8 @@ class Idefics2ForConditionalGeneration(nn.Module):
): ):
"""In place merges in vision_embeddings with inputs_embeds.""" """In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index # mask = input_ids == self.config.image_token_index
mask = input_ids == self.config.image_token_id # replace boolean-mask indexing with torch.where to fix incorrect output under HPU graphs
mask = torch.where(input_ids == self.config.image_token_id)
# Let's pray we have enabled enough slots ! # Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1]) inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds return inputs_embeds
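The new `torch.where`-based merge can be exercised in isolation. A small sketch with made-up shapes, where `99` stands in for the image token id: indexing with the tuple returned by `torch.where` behaves like boolean-mask indexing while avoiding the boolean-mask op that was problematic under HPU graphs.

```python
import torch

input_ids = torch.tensor([[5, 99, 99, 7]])  # 99 stands in for the image token id
inputs_embeds = torch.zeros(1, 4, 3)
image_features = torch.ones(2, 3)           # one row per image token

mask = torch.where(input_ids == 99)         # tuple of index tensors
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
print(inputs_embeds[0, 1])                  # tensor([1., 1., 1.])
```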
@ -739,17 +740,16 @@ class Idefics2ForConditionalGeneration(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor], cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]], kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor, slots: torch.Tensor,
seqlen: Seqlen, seqlen: Seqlen,
max_s: int, hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None, pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None, pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here # Unused here
image_sizes: Optional[torch.Tensor] = None, image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None, adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
): ):
inputs_embeds = self.text_model.embed_tokens(input_ids) inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None: if pixel_values is not None:
@ -793,6 +793,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
].contiguous() ].contiguous()
patch_size = self.config.vision_config.patch_size patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold( patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size dimension=1, size=patch_size, step=patch_size
) )
@ -800,6 +801,21 @@ class Idefics2ForConditionalGeneration(nn.Module):
dimension=2, size=patch_size, step=patch_size dimension=2, size=patch_size, step=patch_size
) )
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
# HPU does not support unfold; emulate it with a conv2d over a ones kernel
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
# Get sequence from the vision encoder # Get sequence from the vision encoder
image_hidden_states = self.vision_model( image_hidden_states = self.vision_model(
@ -825,12 +841,9 @@ class Idefics2ForConditionalGeneration(nn.Module):
position_ids=position_ids, position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill, cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache, kv_cache=kv_cache,
block_tables=block_tables,
slots=slots, slots=slots,
seqlen=seqlen, seqlen=seqlen,
max_s=max_s, hpu_attention_meta=hpu_attention_meta,
true_max_s=max_s,
prefill_cache_indices=None,
adapter_data=adapter_data, adapter_data=adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
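The unfold-to-conv2d patch-mask workaround in this file can also be checked standalone. A minimal sketch, assuming a boolean `pixel_attention_mask` of shape `(batch, height, width)` and square patches; note that the conv2d variant marks a patch valid only when every pixel inside it is valid.

```python
import torch
import torch.nn.functional as F

def patch_mask_via_conv2d(pixel_attention_mask: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Summing a ones kernel over each patch counts the valid pixels per patch.
    kernel = torch.ones(1, 1, patch_size, patch_size, dtype=torch.float32)
    counts = F.conv2d(
        pixel_attention_mask.unsqueeze(1).to(kernel.dtype), kernel, stride=patch_size
    ).squeeze(1)
    # A patch is attended to only if every pixel inside it is valid.
    return torch.eq(counts, patch_size * patch_size)

mask = torch.ones(2, 32, 32, dtype=torch.bool)
print(patch_mask_via_conv2d(mask, patch_size=16).shape)  # torch.Size([2, 2, 2])
```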

View File

@ -0,0 +1,596 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics3 model."""
from typing import List, Optional, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class Idefics3VisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
resolution.
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
which allows treating images in their native aspect ratio and without the need to resize them to the same
fixed size. In particular, we start from the original pre-trained SigLIP model
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
"""
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid",
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
)
self.patch_embedding.bias = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
)
self.num_patches_per_side = self.image_size // self.patch_size
self.num_patches = self.num_patches_per_side**2
self.num_positions = self.num_patches
self.position_embedding = TensorParallelEmbedding(
prefix=f"{prefix}.position_embedding", weights=weights
)
def forward(
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
) -> torch.Tensor:
batch_size, _, max_im_h, max_im_w = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values)
embeddings = patch_embeds.flatten(2).transpose(1, 2)
max_nb_patches_h, max_nb_patches_w = (
max_im_h // self.patch_size,
max_im_w // self.patch_size,
)
boundaries = torch.arange(
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
)
position_ids = torch.full(
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
)
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
nb_patches_h = p_attn_mask[:, 0].sum()
nb_patches_w = p_attn_mask[0].sum()
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
bucket_coords_h = torch.bucketize(
fractional_coords_h, boundaries, right=True
)
bucket_coords_w = torch.bucketize(
fractional_coords_w, boundaries, right=True
)
pos_ids = (
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
).flatten()
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
position_ids = position_ids.to(self.position_embedding.weight.device)
embeddings = embeddings + self.position_embedding(position_ids)
return embeddings
class Idefics3VisionAttention(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_size = self.embed_dim // self.num_heads
if self.head_size * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_size**-0.5
self.dropout = config.attention_dropout
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=True,
)
self.out_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
)
self.is_causal = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, q_len, _ = hidden_states.size()
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(
[
self.head_size * self.num_heads,
self.head_size * self.num_heads,
self.head_size * self.num_heads,
],
dim=2,
)
query_states = query_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
key_states = key_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
value_states = value_states.view(
batch_size, q_len, self.num_heads, self.head_size
).transpose(1, 2)
k_v_seq_len = key_states.shape[-2]
attn_weights = (
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
)
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
raise ValueError(
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
raise ValueError(
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
raise ValueError(
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class Idefics3VisionMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Idefics3EncoderLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics3VisionAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
)
self.mlp = Idefics3VisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Idefics3Encoder(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[
Idefics3EncoderLayer(
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
)
for i in range(config.num_hidden_layers)
]
)
# Ignore copy
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
):
hidden_states = inputs_embeds
for encoder_layer in self.layers:
hidden_states = encoder_layer(
hidden_states,
attention_mask,
)
return hidden_states
class Idefics3VisionTransformer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
self.embeddings = Idefics3VisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.encoder = Idefics3Encoder(
prefix=f"{prefix}.encoder", config=config, weights=weights
)
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
def forward(
self,
pixel_values,
patch_attention_mask: Optional[torch.BoolTensor] = None,
):
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
patch_size = self.config.patch_size
patch_attention_mask = torch.ones(
(
batch_size,
pixel_values.size(2) // patch_size,
pixel_values.size(3) // patch_size,
)
)
patch_attention_mask = patch_attention_mask.to(
dtype=torch.bool, device=pixel_values.device
)
hidden_states = self.embeddings(
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
)
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
# The call to `_upad_input` in `_flash_attention_forward` is expensive
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
# we avoid passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
else:
patch_attention_mask = _prepare_4d_attention_mask(
patch_attention_mask, hidden_states.dtype
)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=patch_attention_mask,
)
last_hidden_state = encoder_outputs
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
class Idefics3SimpleMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
input_size = config.vision_config.hidden_size * (config.scale_factor**2)
output_size = config.text_config.hidden_size
proj = nn.Parameter(
weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
requires_grad=False,
).to(weights.dtype)
self.proj = nn.Linear(input_size, output_size, bias=False)
self.proj.weight = proj
def forward(self, x):
return self.proj(x)
class Idefics3Connector(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
self.scale_factor = config.scale_factor
def pixel_shuffle(self, x, scale_factor=2):
bsz, seq, embed_dim = x.size()
height = width = int(seq**0.5)
x = x.view(bsz, height, width, embed_dim)
x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
x = x.permute(0, 2, 1, 3)
x = x.reshape(
bsz,
int(width / scale_factor),
int(height / scale_factor),
embed_dim * (scale_factor**2),
)
x = x.permute(0, 2, 1, 3)
x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
return x
def forward(self, image_hidden_states):
image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
image_hidden_states = self.modality_projection(image_hidden_states)
return image_hidden_states
class Idefics3ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator
# set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
# since Idefics3 uses the `embed_tokens` for the final prediction
# config.text_config.tie_word_embeddings = True
vision_config = config.vision_config
self.text_model = load_text_model(
prefix="model" if not prefix else f"{prefix}.model",
config=config.text_config,
weights=weights,
name="text_model",
)
self.dtype = weights.dtype
# The vision and connector models are not quantized.
with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
self.vision_model = Idefics3VisionTransformer(
prefix=(
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
),
config=vision_config,
weights=weights,
)
config.quantize = None
self.connector = Idefics3Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config,
weights=weights,
)
self.config = config
self.image_token_id = config.image_token_id
self.pad_token_id = (
config.pad_token_id if config.pad_token_id is not None else -1
)
def _merge_input_ids_with_image_features(
self,
input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
image_features: torch.Tensor,
):
"""In place merges in vision_embeddings with inputs_embeds."""
# mask = input_ids == self.config.image_token_index
# replace boolean-mask indexing with torch.where to fix incorrect output under HPU graphs
mask = torch.where(input_ids == self.config.image_token_id)
# Let's pray we have enabled enough slots !
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor] = None,
pixel_values: torch.FloatTensor = None,
pixel_attention_mask: Optional[torch.BoolTensor] = None,
# Unused here
image_sizes: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.text_model.embed_tokens(input_ids)
if pixel_values is not None:
batch_size, num_images, num_channels, height, width = pixel_values.shape
all_states = []
all_pixel_values = pixel_values
all_pixel_mask = pixel_attention_mask
for i in range(batch_size):
pixel_values = all_pixel_values.to(
dtype=self.dtype
) # fp16 compatibility
pixel_values = pixel_values[i : i + 1]
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
# Remove padding images - padding images are all zeros.
nb_values_per_image = pixel_values.shape[1:].numel()
real_images_inds = (pixel_values == 0.0).sum(
dim=(-1, -2, -3)
) != nb_values_per_image
pixel_values = pixel_values[real_images_inds].contiguous()
# Handle the vision attention mask
if pixel_attention_mask is None:
pixel_attention_mask = torch.ones(
size=(
pixel_values.size(0),
pixel_values.size(2),
pixel_values.size(3),
),
dtype=torch.bool,
device=pixel_values.device,
)
else:
# Remove padding images from the mask
pixel_attention_mask = all_pixel_mask[i : i + 1]
pixel_attention_mask = pixel_attention_mask.view(
1 * num_images, *pixel_attention_mask.shape[2:]
)
pixel_attention_mask = pixel_attention_mask[
real_images_inds
].contiguous()
patch_size = self.config.vision_config.patch_size
"""
patches_subgrid = pixel_attention_mask.unfold(
dimension=1, size=patch_size, step=patch_size
)
patches_subgrid = patches_subgrid.unfold(
dimension=2, size=patch_size, step=patch_size
)
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
"""
# HPU does not support unfold; emulate it with a conv2d over a ones kernel
conv_kernel = torch.ones(
[1, 1, patch_size, patch_size],
dtype=pixel_values.dtype,
device=pixel_values.device,
)
patches_subgrid = torch.nn.functional.conv2d(
pixel_attention_mask.unsqueeze(1).to(conv_kernel.dtype),
conv_kernel,
stride=patch_size,
).squeeze(1)
patch_attention_mask = torch.eq(
patches_subgrid, (patch_size * patch_size)
)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
# Modality projection & resampling
image_hidden_states = self.connector(
image_hidden_states,
)
all_states.append(image_hidden_states)
image_hidden_states = torch.stack(all_states, dim=0)
inputs_embeds = self._merge_input_ids_with_image_features(
input_ids, inputs_embeds, image_hidden_states
)
hidden_states = self.text_model.model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
adapter_data=adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.text_model.lm_head(hidden_states)
return logits, speculative_logits
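The `pixel_shuffle` used by `Idefics3Connector` trades sequence length for embedding width by a factor of `scale_factor**2`. A self-contained sketch with toy shapes:

```python
import torch

def pixel_shuffle(x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
    # Fold spatial patches into the channel dimension, shrinking the sequence.
    bsz, seq, embed_dim = x.size()
    height = width = int(seq ** 0.5)
    x = x.view(bsz, height, width, embed_dim)
    x = x.view(bsz, height, width // scale_factor, embed_dim * scale_factor)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(bsz, width // scale_factor, height // scale_factor, embed_dim * scale_factor**2)
    x = x.permute(0, 2, 1, 3)
    return x.reshape(bsz, seq // scale_factor**2, embed_dim * scale_factor**2)

x = torch.randn(1, 64, 32)        # 8x8 patches, 32-dim embeddings
print(pixel_shuffle(x, 2).shape)  # torch.Size([1, 16, 128])
```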

View File

@ -46,15 +46,9 @@ from text_generation_server.layers import (
FastLinear, FastLinear,
) )
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.utils.import_utils import SYSTEM
from loguru import logger from loguru import logger
if SYSTEM == "cuda": dropout_layer_norm = None
import dropout_layer_norm
elif SYSTEM == "rocm":
from vllm._C import ops
else:
dropout_layer_norm = None
@dataclass @dataclass
@ -351,94 +345,18 @@ class IdeficsRMSNorm(nn.Module):
self.variance_epsilon = eps self.variance_epsilon = eps
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
if SYSTEM == "ipex": from vllm_hpu_extension.kernels import rms_norm
import intel_extension_for_pytorch as ipex
out = ipex.llm.functional.add_rms_norm( orig_shape = hidden_states.shape
residual, if residual is not None:
hidden_states, residual += hidden_states.view(residual.shape)
self.weight,
None,
self.variance_epsilon,
residual is not None,
)
return out
elif hidden_states.shape[-1] > 8192:
if residual is not None:
hidden_states += residual
residual = hidden_states
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
elif SYSTEM == "cuda":
# faster post attention rms norm
unwrap = False
if len(hidden_states.shape) > 2:
unwrap = True
shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, shape[-1])
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
None,
None,
None,
None,
None,
0.0,
self.variance_epsilon,
1.0,
0,
None,
False,
True, # Activate RMSNorm
)
if res is None:
res = hidden_states
if unwrap:
normed_hidden_states = normed_hidden_states.view(*shape)
return normed_hidden_states
elif SYSTEM == "rocm":
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
if residual is not None:
hidden_states += residual
residual = hidden_states
unwrap = False
if len(hidden_states.shape) > 2:
unwrap = True
shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, shape[-1])
out = torch.empty_like(hidden_states)
ops.rms_norm(
out,
hidden_states,
self.weight.data,
self.variance_epsilon,
)
if unwrap:
out = out.view(*shape)
return out
else: else:
raise ValueError( residual = hidden_states
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction." # Note: HPUFusedRMSNorm requires 3D tensors as inputs
) if len(orig_shape) == 2:
residual = residual.unsqueeze(0)
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
return x.view(orig_shape), residual.view(orig_shape)
# this was adapted from LlamaMLP # this was adapted from LlamaMLP
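For comparison with the fused `rms_norm` kernel from `vllm_hpu_extension` used above, here is a plain-PyTorch reference of the assumed RMSNorm semantics (without the fused residual add), which can be used to eyeball numerical parity:

```python
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale.
    hidden = x.to(torch.float32)
    variance = hidden.pow(2).mean(-1, keepdim=True)
    hidden = hidden * torch.rsqrt(variance + eps)
    return weight * hidden.to(weight.dtype)

x = torch.randn(1, 4, 16, dtype=torch.bfloat16)
w = torch.ones(16, dtype=torch.bfloat16)
print(rms_norm_ref(x, w, eps=1e-6).shape)  # torch.Size([1, 4, 16])
```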

View File

@ -196,7 +196,10 @@ class MambaModel(nn.Module):
def __init__(self, config, weights): def __init__(self, config, weights):
super().__init__() super().__init__()
prefix = "backbone" prefix = "backbone"
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) try:
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights)
except RuntimeError:
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i) ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i)
@ -206,7 +209,10 @@ class MambaModel(nn.Module):
self.norm_f = FastRMSNorm.load( self.norm_f = FastRMSNorm.load(
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
) )
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) try:
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
except RuntimeError:
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
self.config = config self.config = config
def forward( def forward(
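The try/except above handles checkpoints that name the embedding either `backbone.embeddings` or `backbone.embedding`. A generic sketch of the same fallback pattern (hypothetical helper, not part of the codebase):

```python
def load_with_fallback(load_fn, prefixes):
    # Try each candidate weight name in order; re-raise the last error if none match.
    last_err = None
    for prefix in prefixes:
        try:
            return load_fn(prefix)
        except RuntimeError as err:
            last_err = err
    raise last_err

# Hypothetical usage mirroring the hunk above:
# self.embed_tokens = load_with_fallback(
#     lambda p: TensorParallelEmbedding(p, weights),
#     ["backbone.embeddings", "backbone.embedding"],
# )
```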

View File

@ -1,796 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPTNeoX model."""
from typing import Optional, Tuple, Union
import os
import torch
import torch.distributed
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
)
CUSTOM_KERNELS_ENABLED = False
if (
torch.cuda.is_available()
and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
):
try:
from custom_kernels import fused_attention_cuda
CUSTOM_KERNELS_ENABLED = True
except ImportError:
pass
def make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.ones(
(target_length, target_length + past_key_values_length),
dtype=torch.bool,
device=device,
)
mask = mask.triu(1 + past_key_values_length)
expanded_mask = mask.unsqueeze(0).expand(
batch_size, target_length, target_length + past_key_values_length
)
return expanded_mask
def expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, tgt_length, src_length)
def prepare_attn_mask(
attention_mask: torch.Tensor,
input_shape: Tuple[int, int],
past_key_values_length: int,
) -> torch.BoolTensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
combined_attention_mask = None
device = attention_mask.device
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, tgt_length, src_length]
expanded_attn_mask = expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
class GPTNeoXPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
class GPTNeoXAttention(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
# ??? TODO
# self.register_buffer(
# "bias",
# torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
# 1, 1, max_positions, max_positions
# ),
# )
# self.register_buffer("masked_bias", torch.tensor(-1e9))
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims,
config.max_position_embeddings,
base=config.rotary_emb_base,
)
self.rotary_emb.inv_freq = nn.Parameter(
weights.get_tensor(f"{prefix}.rotary_emb.inv_freq")
)
self.inv_norm_factor = 1.0 / torch.sqrt(
torch.tensor(self.head_size, dtype=torch.float32)
).to(torch.get_default_dtype())
if self.num_attention_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_attention_heads` must be divisible by `num_shards` "
f"(got `num_attention_heads`: {self.num_attention_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_attention_heads = (
self.num_attention_heads // weights.process_group.size()
)
self.query_key_value = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.query_key_value", weights=weights, bias=True
)
self.dense = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense", weights=weights, bias=True
)
def forward(
self,
hidden_states,
position_ids,
attention_mask,
head_mask=None,
layer_past=None,
use_cache=False,
output_attentions=False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape).permute(0, 2, 1, 3)
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query, key, value = qkv.split(self.head_size, -1)
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
key_rot = key[..., : self.rotary_ndims]
query_rot, key_rot = self.rotary_emb(query_rot, key_rot, position_ids, seq_len)
query[..., : self.rotary_ndims] = query_rot
key[..., : self.rotary_ndims] = key_rot
if CUSTOM_KERNELS_ENABLED:
attn_output, present, attn_weights = fused_attention_cuda.forward(
query,
key,
value,
layer_past,
attention_mask,
head_mask,
self.inv_norm_factor,
self.num_attention_heads,
use_cache,
)
else:
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(
query, key, value, attention_mask, head_mask
)
# Reshape outputs
attn_output = self._merge_heads(
attn_output, self.num_attention_heads, self.head_size
)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(
tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size
)
# -> [bs, seq_len, hidden_size]
return tensor
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
query = query.reshape(
batch_size * num_attention_heads, query_length, attn_head_size
)
key = key.reshape(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
1,
dtype=query.dtype,
device=key.device,
).expand(batch_size * num_attention_heads, query_length, key_length)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=self.inv_norm_factor,
)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attn_scores.dtype
if input_dtype in [torch.float16, torch.bfloat16]:
attn_scores = attn_scores.to(torch.float)
attn_scores = torch.where(
attention_mask, torch.finfo(attn_scores.dtype).min, attn_scores
)
attn_scores = attn_scores.view(
batch_size, num_attention_heads, query_length, key_length
)
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
self.true_inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2).float().to(device) / dim)
)
self.register_buffer("inv_freq", self.true_inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
self.cos_cached = None
self.sin_cached = None
@staticmethod
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
@staticmethod
def _create_cos_sin(inv_freq, max_position_embeddings, dtype, device):
t = torch.arange(
max_position_embeddings, device=inv_freq.device, dtype=inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(device).to(dtype), emb.sin().to(device).to(dtype)
def forward(self, q, k, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if (
seq_len > self.max_seq_len_cached
or self.cos_cached is None
or self.sin_cached is None
):
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
self.cos_cached, self.sin_cached = self._create_cos_sin(
self.true_inv_freq, self.max_seq_len_cached, q.dtype, q.device
)
return rotary_forward(q, k, self.cos_cached, self.sin_cached, position_ids)
@torch.jit.script
def rotary_forward(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
chunk_size = q.shape[-1] // 2
q1, q2 = q.split(chunk_size, -1)
q_rotated = torch.cat((-q2, q1), dim=-1)
k1, k2 = k.split(chunk_size, -1)
k_rotated = torch.cat((-k2, k1), dim=-1)
q_embed = (q * cos) + (q_rotated * sin)
k_embed = (k * cos) + (k_rotated * sin)
return q_embed, k_embed
class GPTNeoXMLP(nn.Module):
def __init__(self, config, prefix, weights):
super().__init__()
self.act = (
ACT2FN[config.hidden_act]
if "gelu_fast" not in config.hidden_act
else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
)
self.dense_h_to_4h = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.dense_h_to_4h", weights=weights, bias=True
)
self.dense_4h_to_h = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.dense_4h_to_h", weights=weights, bias=True
)
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class GPTNeoXLayer(nn.Module):
def __init__(self, layer_id, prefix: str, config, weights):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.layers.{layer_id}.input_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.post_attention_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
self.attention = GPTNeoXAttention(
config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights
)
self.mlp = GPTNeoXMLP(
config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights
)
def forward(
self,
hidden_states,
position_ids,
attention_mask=None,
head_mask=None,
use_cache=False,
layer_past=None,
output_attentions=False,
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[
0
] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
hidden_states = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
hidden_states = mlp_output + attn_output
if use_cache:
outputs = (
hidden_states,
) + outputs # hidden_states, present, (attn_weights)
else:
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
return outputs
class GPTNeoXModel(GPTNeoXPreTrainedModel):
def __init__(self, prefix: str, config, weights):
super().__init__(config)
self.config = config
self.num_attention_heads = config.num_attention_heads
self.embed_in = TensorParallelEmbedding(
prefix=f"{prefix}.embed_in", weights=weights
)
self.layers = nn.ModuleList(
[
GPTNeoXLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.final_layer_norm",
weights=weights,
eps=config.layer_norm_eps,
)
self.tp_world_size = weights.process_group.size()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids=None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_length, seq_length + past_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = inputs_embeds
# Attention mask.
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[-1]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), device=hidden_states.device
)
else:
attention_mask = attention_mask.to(hidden_states.device)
causal_mask = prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
assert self.num_attention_heads % self.tp_world_size == 0
block_size = self.num_attention_heads // self.tp_world_size
causal_mask = torch.repeat_interleave(causal_mask, block_size, dim=0)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = layer(
hidden_states,
position_ids=position_ids,
attention_mask=causal_mask,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_layer_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_attentions]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, prefix: str, config, weights):
super().__init__(config)
if not prefix:
prefix = "gpt_neox"
else:
prefix = f"{prefix}.gpt_neox"
self.gpt_neox = GPTNeoXModel(prefix, config, weights)
self.embed_out = SpeculativeHead.load(
config, prefix="embed_out", weights=weights
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config.is_decoder = True
>>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
outputs = self.gpt_neox(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
lm_logits, speculative_logits = self.embed_out(hidden_states)
lm_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return (
CausalLMOutputWithPast(
loss=lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
),
speculative_logits,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)
return model_inputs
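# Hedged illustrative sketch (not called at runtime): how position_ids are rebuilt
# on the fly from a left-padded attention mask, mirroring the cumsum/masked_fill_
# logic in prepare_inputs_for_generation above.
def _example_position_ids_from_mask():
    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two padding tokens, three real
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    assert position_ids.tolist() == [[1, 1, 0, 1, 2]]
    # with a past cache, only the last position is fed to the next decode step
    assert position_ids[:, -1].unsqueeze(-1).tolist() == [[2]]
    return position_ids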
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past[:2]
)
+ layer_past[2:],
)
return reordered_past

View File

@ -1,857 +0,0 @@
# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OPT model."""
import random
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import OPTConfig
from text_generation_server.layers import (
FastLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
SpeculativeHead,
)
EPS = 1e-5
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full(
(tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device,
)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
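# Hedged illustrative sketch (not called at runtime): _expand_mask turns a [bsz, src_len]
# padding mask into an additive [bsz, 1, tgt_len, src_len] bias where padded positions
# hold the dtype minimum and visible positions hold zero.
def _example_expand_padding_mask():
    import torch

    mask = torch.tensor([[1, 1, 0]])  # the last token is padding
    bias = _expand_mask(mask, dtype=torch.float32, tgt_len=2)
    assert bias.shape == (1, 1, 2, 3)
    assert bias[0, 0, 0, 0].item() == 0.0
    assert bias[0, 0, 0, 2].item() == torch.finfo(torch.float32).min
    return bias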
class OPTLearnedPositionalEmbedding(nn.Module):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, prefix: str, weights):
super().__init__()
self.offset = 2
self.weight = nn.Parameter(
weights.get_tensor(
f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"
)
)
def forward(
self, attention_mask: torch.LongTensor, past_key_values_length: int = 0
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (
torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return torch.nn.functional.embedding(positions + self.offset, self.weight)
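# Hedged illustrative sketch (not called at runtime): OPT derives positions from the
# attention mask (padding collapses to -1) and adds the fixed offset of 2 before the
# embedding lookup, matching the original OPT checkpoint layout.
def _example_opt_learned_positions():
    import torch

    attention_mask = torch.tensor([[0, 1, 1, 1]])
    positions = (
        torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask
    ).long() - 1
    assert positions.tolist() == [[-1, 0, 1, 2]]
    offset = 2  # same value as OPTLearnedPositionalEmbedding.offset
    assert (positions + offset).tolist() == [[1, 2, 3, 4]]
    return positions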
class OPTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
config,
prefix,
weights,
is_decoder: bool = False,
bias: bool = True,
process_group=None,
):
super().__init__()
hidden_size = config.hidden_size
num_heads = config.num_attention_heads
self.hidden_size = hidden_size
self.num_heads = num_heads
self.dropout = config.dropout
self.head_dim = hidden_size // num_heads
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
process_group = weights.process_group
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // process_group.size()
self.hidden_size = self.hidden_size // process_group.size()
self.q_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.q_proj", weights=weights, bias=bias
)
self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=bias
)
self.v_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.v_proj", weights=weights, bias=bias
)
self.out_proj = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.out_proj", weights=weights, bias=bias
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
if attn_weights.dtype == torch.float16:
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(torch.float16)
else:
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if layer_head_mask is not None:
if layer_head_mask.size() != (self.num_heads,):
raise ValueError(
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
f" {layer_head_mask.size()}"
)
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights_reshaped.view(
bsz * self.num_heads, tgt_len, src_len
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.transpose(1, 2)
# Use the `hidden_size` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned across GPUs when using tensor parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.hidden_size)
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights_reshaped, past_key_value
class OPTDecoderLayer(nn.Module):
def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights):
super().__init__()
self.process_group = weights.process_group
self.hidden_size = config.hidden_size
prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}"
self.self_attn = OPTAttention(
config,
prefix=f"{prefix}.self_attn",
weights=weights,
is_decoder=True,
bias=config.enable_bias,
)
self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.self_attn_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.self_attn_layer_norm", weights=weights, eps=EPS
)
self.fc1 = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.fc1", weights=weights, bias=config.enable_bias
)
self.fc2 = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.fc2", weights=weights, bias=config.enable_bias
)
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.final_layer_norm", weights=weights, eps=EPS
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
`(encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
if self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = residual + hidden_states
# 350m applies layer norm AFTER attention
if not self.do_layer_norm_before:
hidden_states = self.self_attn_layer_norm(hidden_states)
# Fully Connected
hidden_states_shape = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
residual = hidden_states
# 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
if self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training
)
hidden_states = (residual + hidden_states).view(hidden_states_shape)
# 350m applies layer norm AFTER the feed-forward block
if not self.do_layer_norm_before:
hidden_states = self.final_layer_norm(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class OPTPreTrainedModel(PreTrainedModel):
config_class = OPTConfig
class OPTDecoder(OPTPreTrainedModel):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.dropout = config.dropout
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.vocab_size = config.vocab_size
prefix = prefix + "." if prefix else ""
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}decoder.embed_tokens", weights=weights
)
self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights)
if config.word_embed_proj_dim != config.hidden_size:
self.project_out = FastLinear.load(
config,
prefix=f"{prefix}decoder.project_out",
weights=weights,
bias=False,
)
else:
self.project_out = None
if config.word_embed_proj_dim != config.hidden_size:
self.project_in = FastLinear.load(
config,
prefix=f"{prefix}decoder.project_in",
weights=weights,
bias=False,
)
else:
self.project_in = None
# Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
# with checkpoints that have been fine-tuned before transformers v4.20.1
# see https://github.com/facebookresearch/metaseq/pull/164
if config.do_layer_norm_before and not config._remove_final_layer_norm:
self.final_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}decoder.final_layer_norm", weights=weights, eps=EPS
)
else:
self.final_layer_norm = None
self.layers = nn.ModuleList(
[
OPTDecoderLayer(layer_id, prefix, config, weights)
for layer_id in range(config.num_hidden_layers)
]
)
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape
past_key_values_length = (
past_key_values[0][0].shape[2] if past_key_values is not None else 0
)
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values_length + seq_length
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
batch_size, mask_seq_length, device=inputs_embeds.device
)
causal_attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
if attn_mask is not None:
if attn_mask.size()[0] != (len(self.layers)):
raise ValueError(
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}."
)
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class OPTModel(OPTPreTrainedModel):
def __init__(self, prefix: str, config: OPTConfig, weights):
super().__init__(config)
self.decoder = OPTDecoder(prefix, config, weights)
# Initialize weights and apply final processing
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if not return_dict:
return decoder_outputs
return BaseModelOutputWithPast(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
hidden_states=decoder_outputs.hidden_states,
attentions=decoder_outputs.attentions,
)
class OPTForCausalLM(OPTPreTrainedModel):
def __init__(self, prefix, config, weights):
super().__init__(config)
self.model = OPTModel(prefix, config, weights)
self.lm_head = SpeculativeHead.load(
config,
prefix=f"{prefix + '.' if prefix else ''}decoder.embed_tokens",
weights=weights,
)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits, speculative_logits = self.lm_head(outputs.last_hidden_state)
loss = None
return (
CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
),
speculative_logits,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past
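# Hedged illustrative sketch (not called at runtime): beam search reorders the cached
# key/value tensors along the batch dimension so each beam keeps the past of the
# hypothesis it forked from, which is exactly what _reorder_cache does above.
def _example_reorder_cache():
    import torch

    key = torch.arange(6.0).view(3, 1, 2, 1)  # [batch=3, heads=1, seq=2, head_dim=1]
    value = key.clone()
    past_key_values = ((key, value),)
    beam_idx = torch.tensor([2, 0, 0])  # new beams 1 and 2 both fork from old row 0
    reordered = tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past_key_values
    )
    assert reordered[0][0][0, 0, 0, 0].item() == 4.0  # new row 0 came from old row 2
    return reordered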

View File

@ -1,336 +0,0 @@
# Implementation of the PhiModel and PhiForCausalLM classes
import torch
import torch.distributed
import math
from torch import nn
from typing import Optional, List, Tuple
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
SpeculativeHead,
FastLinear,
)
# PhiConfig is the configuration class for the PhiModel.
class PhiConfig(PretrainedConfig):
def __init__(
self,
vocab_size=51200,
n_positions=2048,
n_embd=2560,
n_layer=32,
n_inner=None,
n_head=32,
rotary_dim=32,
layer_norm_epsilon=1e-5,
tie_word_embeddings=False,
pad_vocab_size_multiple=64,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
no_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_inner = n_inner
self.n_head = n_head
self.rotary_dim = rotary_dim
self.layer_norm_epsilon = layer_norm_epsilon
self.tie_word_embeddings = tie_word_embeddings
self.pad_vocab_size_multiple = pad_vocab_size_multiple
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.no_bias = no_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# RotaryEmbedding is a class that implements the rotary embedding.
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
inv_freq = [1.0 / 10000.0 ** (i / dim) for i in range(0, dim, 2)]
inv_freq_len = len(inv_freq)
inv_freq = torch.tensor(inv_freq).view(1, inv_freq_len)
t = torch.arange(0, max_seq_len, dtype=torch.float).view(max_seq_len, 1)
freqs = t.matmul(inv_freq)
self.sin = freqs.sin()
self.cos = freqs.cos()
def apply_rotary_emb_qkv(self, qkv, seqlen_offset):
b_size, seqlen, three, _, _headdim = qkv.shape
if three != 3:
raise Exception("unexpected shape for qkv")
_, rotary_dim = self.cos.shape
rotary_dim = rotary_dim * 2
q_rot = qkv[:, :, 0, :, :rotary_dim]
q_pass = qkv[:, :, 0, :, rotary_dim:]
k_rot = qkv[:, :, 1, :, :rotary_dim]
k_pass = qkv[:, :, 1, :, rotary_dim:]
q12 = torch.chunk(q_rot, 2, dim=-1)
k12 = torch.chunk(k_rot, 2, dim=-1)
q1, q2 = q12[0], q12[1]
k1, k2 = k12[0], k12[1]
c = self.cos.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
s = self.sin.narrow(0, seqlen_offset, seqlen).unsqueeze(1)
q_rot = torch.cat(
[
q1 * c - q2 * s,
q1 * s + q2 * c,
],
dim=-1,
)
k_rot = torch.cat(
[
k1 * c - k2 * s,
k1 * s + k2 * c,
],
dim=-1,
)
q = torch.cat([q_rot, q_pass], dim=-1)
k = torch.cat([k_rot, k_pass], dim=-1)
v = qkv[:, :, 2]
return q, k, v
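# Hedged illustrative sketch (not called at runtime): Phi applies rotary embeddings only
# to the first rotary_dim channels of each head; the remaining channels pass through
# unchanged. The head_dim of 32 and rotary dim of 8 below are arbitrary example values.
def _example_partial_rotary():
    import torch

    rotary = RotaryEmbedding(dim=8, max_seq_len=16)
    qkv = torch.randn(1, 4, 3, 2, 32)  # [batch, seqlen, 3, num_heads, head_dim]
    q, k, v = rotary.apply_rotary_emb_qkv(qkv, seqlen_offset=0)
    assert q.shape == (1, 4, 2, 32) and v.shape == (1, 4, 2, 32)
    # channels past the rotary dim (2 * cos.shape[-1] == 8) are left untouched
    assert torch.allclose(q[..., 8:], qkv[:, :, 0, :, 8:])
    return q, k, v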
# PhiCausalLMHead is the head of the PhiModel. It is a linear layer with a layer norm.
class PhiCausalLMHead(nn.Module):
def __init__(self, config, weights):
super().__init__()
self.ln = nn.LayerNorm.load(
prefix="lm_head.ln",
weights=weights,
eps=config.layer_norm_epsilon,
)
self.linear = SpeculativeHead.load(
config=config, prefix="lm_head.linear", weights=weights
)
def forward(self, hidden_states):
hidden_states = self.ln(hidden_states)
hidden_states = self.linear(hidden_states)
return hidden_states
# PhiMHA is a multi-head attention layer. This layer uses an attention mask to prevent tokens from attending to subsequent tokens.
class PhiMHA(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.Wqkv = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.Wqkv", weights=weights, bias=not config.no_bias
)
self.out_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.out_proj",
weights=weights,
bias=not config.no_bias,
)
self.op_size = config.n_embd
self.head_dim = int(config.n_embd / config.n_head)
self.num_heads = config.n_head
self.rotary_emb = RotaryEmbedding(
config.rotary_dim,
config.n_positions,
)
self.softmax_scale = 1.0 / math.sqrt(self.head_dim)
def forward(
self,
hidden_states,
past_kv_cache,
attention_mask=None,
):
b_size, seq_len, _n_embd = hidden_states.shape
qkv = self.Wqkv(hidden_states)
qkv = qkv.view(b_size, seq_len, 3, self.num_heads, self.head_dim)
seqlen_offset = 0 if past_kv_cache is None else past_kv_cache[0].shape[1]
q, k, v = self.rotary_emb.apply_rotary_emb_qkv(qkv, seqlen_offset)
# if there is a kv_cache, then we need to concatenate
if past_kv_cache is not None:
prev_k, prev_v = past_kv_cache
k = torch.cat([prev_k, k], dim=1)
v = torch.cat([prev_v, v], dim=1)
past_kv_cache = [k, v]
attn_weights = torch.einsum("bthd,bshd->bhts", q, k * self.softmax_scale)
if attention_mask is not None:
seqlen_k = k.shape[1]
seqlen_q = q.shape[1]
causal_mask = torch.triu(
torch.full((seqlen_q, seqlen_k), -10000.0, device=attn_weights.device),
1,
)
attn_weights = attn_weights + causal_mask.to(dtype=attn_weights.dtype)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = attn_weights.matmul(v.transpose(1, 2)).squeeze(0)
attn_output = (
attn_output.view((b_size, self.num_heads, seq_len, self.head_dim))
.transpose(1, 2)
.flatten(-2)
)
return self.out_proj(attn_output), past_kv_cache
# PhiMLP is a multi-layer perceptron. It contains two linear layers with a gelu activation function.
class PhiMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.n_inner = config.n_inner
self.fc1 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc1",
weights=weights,
bias=False,
)
self.fc2 = FastLinear.load(
config=config,
prefix=f"{prefix}.fc2",
weights=weights,
bias=False,
)
self.activation = torch.nn.functional.gelu
def forward(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# PhiBlock is a single transformer block. It contains a layer norm, a multi-head attention layer and a multi-layer perceptron.
class PhiBlock(nn.Module):
def __init__(self, layer_id, config, weights):
super().__init__()
self.layer_id = layer_id
self.layer_norm = nn.LayerNorm.load(
prefix=f"{layer_id}.ln", weights=weights, eps=config.layer_norm_epsilon
)
self.mixer = PhiMHA(prefix=f"{layer_id}.mixer", config=config, weights=weights)
self.mlp = PhiMLP(prefix=f"{layer_id}.mlp", config=config, weights=weights)
def forward(
self,
hidden_states,
kv_cache,
attention_mask,
):
residual = hidden_states
hidden_states = self.layer_norm(hidden_states)
attn_outputs, past_kv_cache = self.mixer(
hidden_states, kv_cache, attention_mask
)
feed_forward_hidden_states = self.mlp(hidden_states)
out = attn_outputs + feed_forward_hidden_states + residual
return out, past_kv_cache
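# Hedged illustrative toy (not called at runtime; nn.Linear stands in for PhiMHA/PhiMLP):
# PhiBlock uses a parallel residual layout -- attention and the MLP both read the same
# layer-normed input and their outputs are summed with the untouched residual, instead
# of the attention output being normalized again and fed into the MLP.
def _example_parallel_residual():
    import torch
    from torch import nn

    hidden = torch.randn(1, 4, 16)
    norm = nn.LayerNorm(16)
    attn = nn.Linear(16, 16)  # stand-in for the attention mixer
    mlp = nn.Linear(16, 16)  # stand-in for the MLP
    normed = norm(hidden)
    out = attn(normed) + mlp(normed) + hidden  # parallel branches share one norm
    assert out.shape == hidden.shape
    return out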
# PhiModel implements the embedding layer and the transformer blocks.
class PhiModel(nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
self.tp_rank = weights.process_group.rank()
self.tp_world_size = weights.process_group.size()
self.embed_tokens = TensorParallelEmbedding(
prefix=f"{prefix}.embd.wte", weights=weights
)
self.blocks = nn.ModuleList(
[
PhiBlock(f"{prefix}.h.{layer_id}", config, weights)
for layer_id in range(config.n_layer)
]
)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
hidden_states = self.embed_tokens(input_ids)
seq_len = hidden_states.shape[1]
mask = None if seq_len <= 1 else attention_mask
past_key_values = (
[None] * len(self.blocks) if past_key_values is None else past_key_values
)
for index, block in enumerate(self.blocks):
hidden_states, new_key_values = block(
hidden_states, past_key_values[index], mask
)
past_key_values[index] = new_key_values
return hidden_states, past_key_values
# PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object.
class PhiForCausalLM(torch.nn.Module):
def __init__(self, prefix: str, config, weights):
super().__init__()
if not prefix:
prefix = "transformer"
else:
prefix = f"{prefix}.transformer"
self.model = PhiModel(prefix, config, weights)
self.lm_head = PhiCausalLMHead(config, weights)
def forward(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
attention_mask: Optional[torch.ByteTensor] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
model_output = self.model(
input_ids, past_key_values, attention_mask, return_dict, use_cache
)
logits = self.lm_head(model_output[0])
loss = None
if labels is not None:
loss = nn.CrossEntropyLoss()(
logits[:, :-1].view(-1, logits.size(-1)), labels[:, 1:].view(-1)
)
if not return_dict:
return (
((loss,) + (logits,) + model_output[1:])
if loss is not None
else (logits,) + model_output[1:]
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=model_output[1],
hidden_states=None,
attentions=None,
)

View File

@ -0,0 +1,946 @@
# coding=utf-8
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2.5 VL model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import numpy as np
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
from typing import Union
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import (
ProcessingKwargs,
ProcessorMixin,
Unpack,
VideosKwargs,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
fps: Union[List[float], float]
class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
_defaults = {
"text_kwargs": {
"padding": False,
},
"videos_kwargs": {"fps": 2.0},
}
class Qwen2_5_VLProcessor(ProcessorMixin):
r"""
Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
[`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
[`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
Args:
image_processor ([`Qwen2VLImageProcessor`], *optional*):
The image processor is a required input.
tokenizer ([`Qwen2TokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""
attributes = ["image_processor", "tokenizer"]
valid_kwargs = ["chat_template"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
def __init__(
self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
):
self.image_token = (
"<|image_pad|>"
if not hasattr(tokenizer, "image_token")
else tokenizer.image_token
)
self.video_token = (
"<|video_pad|>"
if not hasattr(tokenizer, "video_token")
else tokenizer.video_token
)
super().__init__(image_processor, tokenizer, chat_template=chat_template)
def __call__(
self,
images: ImageInput = None,
text: Union[
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
] = None,
videos: VideoInput = None,
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
) -> BatchFeature:
"""
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. Both channels-first and channels-last formats are supported.
text (`str`, `List[str]`, `List[List[str]]`):
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
- **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
- **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
"""
output_kwargs = self._merge_kwargs(
Qwen2_5_VLProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(
images=images, videos=None, **output_kwargs["images_kwargs"]
)
image_grid_thw = image_inputs["image_grid_thw"]
else:
image_inputs = {}
image_grid_thw = None
if videos is not None:
videos_inputs = self.image_processor(
images=None, videos=videos, **output_kwargs["images_kwargs"]
)
video_grid_thw = videos_inputs["video_grid_thw"]
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
if isinstance(fps, (int, float)):
second_per_grid_ts = [
self.image_processor.temporal_patch_size / fps
] * len(video_grid_thw)
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
second_per_grid_ts = [
self.image_processor.temporal_patch_size / tmp for tmp in fps
]
else:
raise ValueError(
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
)
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
else:
videos_inputs = {}
video_grid_thw = None
if not isinstance(text, list):
text = [text]
if image_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.image_token in text[i]:
text[i] = text[i].replace(
self.image_token,
"<|placeholder|>"
* (image_grid_thw[index].prod() // merge_length),
1,
)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.image_token)
if video_grid_thw is not None:
merge_length = self.image_processor.merge_size**2
index = 0
for i in range(len(text)):
while self.video_token in text[i]:
text[i] = text[i].replace(
self.video_token,
"<|placeholder|>"
* (video_grid_thw[index].prod() // merge_length),
1,
)
index += 1
text[i] = text[i].replace("<|placeholder|>", self.video_token)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
refer to the docstring of this method for more information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
the docstring of this method for more information.
"""
return self.tokenizer.decode(*args, **kwargs)
def post_process_image_text_to_text(self, generated_outputs):
"""
Post-process the output of the model to decode the text.
Args:
generated_outputs (`torch.Tensor` or `np.ndarray`):
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
or `(sequence_length,)`.
Returns:
`List[str]`: The decoded text.
"""
return self.tokenizer.batch_decode(
generated_outputs,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
@property
def model_input_names(self):
tokenizer_input_names = self.tokenizer.model_input_names
image_processor_input_names = self.image_processor.model_input_names
names_from_processor = list(
dict.fromkeys(tokenizer_input_names + image_processor_input_names)
)
return names_from_processor + ["second_per_grid_ts"]
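# Hedged illustrative sketch (not called at runtime): each <|image_pad|> placeholder in
# the prompt is expanded to one token per merged vision patch, i.e.
# grid_t * grid_h * grid_w // merge_size**2, mirroring the replacement loop in __call__.
# The 36x36 patch grid and merge_size of 2 below are example values, not defaults.
def _example_image_token_count():
    import torch

    image_grid_thw = torch.tensor([1, 36, 36])  # one image: 1 temporal step, 36x36 patches
    merge_size = 2
    num_image_tokens = int(image_grid_thw.prod().item()) // merge_size**2
    assert num_image_tokens == 324
    return num_image_tokens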
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
class Qwen2_5_VLVisionConfig(PretrainedConfig):
model_type = "qwen2_5_vl"
base_config_key = "vision_config"
def __init__(
self,
depth=32,
hidden_size=3584,
hidden_act="silu",
intermediate_size=3420,
num_heads=16,
in_channels=3,
patch_size=14,
spatial_merge_size=2,
spatial_patch_size=14,
temporal_patch_size=2,
tokens_per_second=4,
window_size=112,
out_hidden_size=3584,
fullatt_block_indexes=[7, 15, 23, 31],
**kwargs,
):
super().__init__(**kwargs)
self.depth = depth
self.hidden_size = hidden_size
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.num_heads = num_heads
self.in_channels = in_channels
self.patch_size = patch_size
self.spatial_patch_size = spatial_patch_size
self.spatial_merge_size = spatial_merge_size
self.temporal_patch_size = temporal_patch_size
self.tokens_per_second = tokens_per_second
self.window_size = window_size
self.fullatt_block_indexes = fullatt_block_indexes
self.out_hidden_size = out_hidden_size
class Qwen2_5_VLConfig(PretrainedConfig):
def __init__(
self,
vocab_size=152064,
hidden_size=8192,
intermediate_size=29568,
num_hidden_layers=80,
num_attention_heads=64,
num_key_value_heads=8,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-05,
use_cache=True,
tie_word_embeddings=False,
rope_theta=1000000.0,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=80,
attention_dropout=0.0,
vision_config=None,
rope_scaling=None,
**kwargs,
):
if vision_config is not None:
self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = sliding_window
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
# and change the type from 'mrope' to 'default' because `mrope` does the default RoPE calculations
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
# TODO: @raushan update config in the hub
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
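# Hedged illustrative sketch (not called at runtime): the backward-compatibility block
# above rewrites a legacy {"type": "mrope"} rope_scaling entry so that the canonical
# "rope_type" key ends up holding "default".
def _example_rope_scaling_bc():
    rope_scaling = {"type": "mrope"}
    if rope_scaling is not None and "type" in rope_scaling:
        if rope_scaling["type"] == "mrope":
            rope_scaling["type"] = "default"
        rope_scaling["rope_type"] = rope_scaling["type"]
    assert rope_scaling == {"type": "default", "rope_type": "default"}
    return rope_scaling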
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2_5VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size // weights.process_group.size()
self.head_dim = config.hidden_size // config.num_heads
self.num_heads = config.num_heads // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv",
weights=weights,
bias=False,
num_heads=self.num_heads,
num_key_value_heads=self.num_heads,
)
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
self.proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.proj",
weights=weights,
bias=True,
)
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
def forward(
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
qkv = self.qkv(hidden_state)
query, key, value = qkv.split(
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
)
# reshape the query, key, and value tensors
_shape = (
hidden_state.shape[0],
self.num_heads,
self.embed_dim // self.num_heads,
)
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
# make q/k/v contiguous before the fused SDPA call
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
causal = False
# execute sdpa
query = query.unsqueeze(0).transpose(1, 2)
key = key.unsqueeze(0).transpose(1, 2)
value = value.unsqueeze(0).transpose(1, 2)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=None,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
# reshape output to original dimensions
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2_5VLVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.activation_fn = ACT2FN[config.hidden_act]
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
self.up = TensorParallelColumnLinear.load(
prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True
)
self.gate = TensorParallelColumnLinear.load(
prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True
)
self.down = TensorParallelRowLinear.load(
prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
gate_states = self.gate(hidden_states)
up_states = self.up(hidden_states)
activated_states = self.activation_fn(gate_states) * up_states
down_states = self.down(activated_states)
return down_states
class Qwen2_5VLVisionBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.attn = Qwen2_5VLAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
)
self.norm1 = FastRMSNorm.load(
prefix=f"{prefix}.norm1",
weights=weights,
eps=1e-6,
)
self.norm2 = FastRMSNorm.load(
prefix=f"{prefix}.norm2",
weights=weights,
eps=1e-6,
)
self.mlp = Qwen2_5VLVisionMLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
norm1_out, _ = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
hidden_states = hidden_states + attn_out
norm2_out, _ = self.norm2(hidden_states)
mlp_out = self.mlp(norm2_out)
hidden_states = hidden_states + mlp_out
return hidden_states
class Qwen2_5VLPatchMerger(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
self.patch_merger_ln_q = FastRMSNorm.load(
prefix=f"{prefix}.ln_q",
weights=weights,
eps=1e-6,
)
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
hidden_states = hidden_states.view(-1, self.hidden_size)
hidden_states = self.fc1(hidden_states)
hidden_states = F.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2_5VisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.spatial_merge_size = config.spatial_merge_size
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
self.patch_embedding = nn.Conv3d(
in_channels=config.in_channels,
out_channels=config.hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
)
head_dim = config.hidden_size // config.num_heads
theta = 10000.0
dim = head_dim // 2
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.blocks = nn.ModuleList(
[
Qwen2_5VLVisionBlock(
prefix=f"{prefix}.blocks.{i}",
config=config,
weights=weights,
)
for i in range(config.depth)
]
)
self.merger = Qwen2_5VLPatchMerger(
prefix=f"{prefix}.merger",
config=config,
weights=weights,
)
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.hidden_size
self.window_size = config.window_size
self.patch_size = config.patch_size
self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
self.fullatt_block_indexes = config.fullatt_block_indexes
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = (
self.window_size // self.spatial_merge_size // self.patch_size
)
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.spatial_merge_size,
grid_w // self.spatial_merge_size,
)
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
grid_t, llm_grid_h, llm_grid_w
)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = (
seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
)
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
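# Illustrative numbers (assuming the published Qwen2.5-VL vision config: window_size=112,
# patch_size=14, spatial_merge_size=2): vit_merger_window_size = 112 // 2 // 14 = 4, i.e.
# each window covers 4x4 merged patches. `window_index` reorders merged patches window by
# window, and `cu_window_seqlens` accumulates per-window token counts in pre-merge tokens
# (hence the multiplication by spatial_merge_unit).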
def forward(
self,
pixel_values: torch.Tensor,
grid_thw: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
# reshape the input tensor for processing
shape = (
-1,
self.in_channels,
self.temporal_patch_size,
self.spatial_patch_size,
self.spatial_patch_size,
)
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
# TODO: revisit to see if we can avoid some of these reshapes
# find the position ids for the input tensor based on the grid_thw
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
# apply the positional embeddings to the position ids
seq = torch.arange(
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
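# pos_ids is (num_patches, 2) holding (row, col) grid indices, so indexing
# rotary_pos_emb_full and flattening yields (num_patches, head_dim // 2): the first half
# of the rotary dims encodes the patch row and the second half the patch column, i.e. the
# 2D rotary embedding used for the vision tower.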
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
seq_len = hidden_states.shape[0]
patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
og_shape = (seq_len, -1)
hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
og_shape
)
rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
og_shape
)
rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device="cpu",
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens).to(
hidden_states.device
)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
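# For example, a single image with grid_thw = [[1, 4, 6]] gives cu_seqlens = [0, 24] and
# max_seqlen = 24 (one full-attention segment over all 24 patches); a video with t frames
# contributes t segments of h * w patches each.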
# iteratively apply the blocks to the hidden states
for layer_num, block in enumerate(self.blocks):
# NOTE: qwen2_5_vl.py has a concept of full attention blocks
# that are applied at specific layers.
if layer_num in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = block(
hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
)
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states
class Qwen2_5VLForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
# set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
# returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment
if (
hasattr(config, "rope_scaling")
and config.rope_scaling is not None
and config.rope_scaling.get("type", None) == "default"
):
config.rope_scaling.update({"rope_type": "mrope"})
self.hidden_size = config.hidden_size
self.vision_start_token_id = config.vision_start_token_id
self.vision_end_token_id = config.vision_end_token_id
self.image_token_id = config.image_token_id
self.video_token_id = config.video_token_id
self.spatial_merge_size = config.vision_config.spatial_merge_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.visual = Qwen2_5VisionModel(
prefix="visual", config=config.vision_config, weights=weights
)
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
self.device = weights.device
# based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
# modified to first find segments then initialize position ids for each segment
# Steps:
# locate all vision and text segments
# calculate `vision_segment_lengths` for each vision segment to be used as an offset
# calculate `text_segment_lengths` for each text segment to be used as offset
# create position ids for each vision segment based on the image grid
# create position ids for each text segment
# combine all the position ids
# the final segment is the difference between the last vision segment and the end of the input
# combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
def get_position_ids(
self,
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if image_grid_thw is None:
return (
torch.arange(input_ids.shape[0], device=input_ids.device)
.unsqueeze(1)
.repeat(1, 3)
)
spatial_merge_size = self.spatial_merge_size
vision_start_token_id = self.vision_start_token_id
vision_end_token_id = self.vision_end_token_id
device = input_ids.device
dtype = input_ids.dtype
input_ids_len = input_ids.shape[0]
vision_starts = torch.where(input_ids == vision_start_token_id)[0]
vision_ends = torch.where(input_ids == vision_end_token_id)[0]
vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
prev_vision_end = torch.cat(
[torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
)
text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
vision_widths_max = torch.cat(
[
torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
image_grid_thw[:-1, 2] // spatial_merge_size,
]
)
vision_segment_lengths = vision_widths_max + text_lengths_between_vision
vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
# create position ids for each vision segment based on the image grid
llm_pos_ids_list = []
for i, _ in enumerate(vision_segments):
t, h, w = (
image_grid_thw[i][0],
image_grid_thw[i][1] // spatial_merge_size,
image_grid_thw[i][2] // spatial_merge_size,
)
t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
w_indices = torch.arange(w, device=device).repeat(t * h)
image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
# offset by the position of the last vision segment
im = image_position_ids + vision_segment_lengths[i]
llm_pos_ids_list.append(im)
# create position ids for each text segment
text_ranges = [
torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ text_segment_lengths[i]
for i, seq_len in enumerate(text_lengths_between_vision)
]
full_llm_pos_ids_list = [
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
]
max_s = full_llm_pos_ids_list[-1].max() + 1
final_text_len = input_ids_len - vision_ends[-1]
if final_text_len > 0:
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
full_llm_pos_ids_list.append(m + max_s)
position_ids = (
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
)
return position_ids
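# Rough illustration (hypothetical layout): for "<text><vision_start><image tokens>
# <vision_end><text>", all three position rows advance together over text, while inside
# the image span the temporal row stays flat and the height/width rows enumerate the
# merged grid; the per-segment offsets computed above keep positions increasing across
# segments.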
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
# Unused in this model
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -0,0 +1,519 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2 VL model."""
from typing import Optional, Tuple, List
import torch
import torch.utils.checkpoint
from torch import nn
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import numpy as np
from transformers.activations import ACT2FN
import torch.nn.functional as F
from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelRowLinear,
TensorParallelEmbedding,
SpeculativeHead,
)
from text_generation_server.layers.attention import (
Seqlen,
HPUPagedAttentionMetadata,
)
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
Qwen2Model,
)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class Qwen2VLAttention(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.embed_dim = config.embed_dim // weights.process_group.size()
self.head_dim = config.hidden_size // config.num_heads
self.num_heads = config.num_heads // weights.process_group.size()
self.qkv = TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv",
weights=weights,
bias=False,
num_heads=self.num_heads,
num_key_value_heads=self.num_heads,
)
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
self.proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.proj",
weights=weights,
bias=True,
)
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
def forward(
self,
hidden_state: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: int,
) -> torch.Tensor:
# apply the qkv linear layer to the hidden state
qkv = self.qkv(hidden_state)
query, key, value = qkv.split(
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
)
# reshape the query, key, and value tensors
_shape = (
hidden_state.shape[0],
self.num_heads,
self.embed_dim // self.num_heads,
)
query = query.view(*_shape)
key = key.view(*_shape)
value = value.view(*_shape)
# apply rotary positional embeddings
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
0
)
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
# make query/key/value contiguous before the fused attention call
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
causal = False
# execute sdpa
query = query.unsqueeze(0).transpose(1, 2)
key = key.unsqueeze(0).transpose(1, 2)
value = value.unsqueeze(0).transpose(1, 2)
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
attn_output = fsdpa_op(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=causal,
scale=None,
softmax_mode="None",
recompute_mode=None,
valid_sequence_lengths=None,
)
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
# reshape output to original dimensions
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
attn_output = self.proj(attn_output)
return attn_output
class Qwen2VLVisionMLP(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2VLVisionBlock(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.attn = Qwen2VLAttention(
prefix=f"{prefix}.attn",
config=config,
weights=weights,
)
self.norm1 = FastLayerNorm.load(
prefix=f"{prefix}.norm1",
weights=weights,
eps=1e-6,
)
self.norm2 = FastLayerNorm.load(
prefix=f"{prefix}.norm2",
weights=weights,
eps=1e-6,
)
self.mlp = Qwen2VLVisionMLP(
prefix=f"{prefix}.mlp",
config=config,
weights=weights,
)
def forward(
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
) -> torch.Tensor:
norm1_out, residual = self.norm1(hidden_states)
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
hidden_states = attn_out + residual
norm2_out, residual = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(norm2_out)
return hidden_states
class Qwen2VLPatchMerger(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)
self.patch_merger_ln_q = FastLayerNorm.load(
prefix=f"{prefix}.ln_q",
weights=weights,
eps=1e-6,
)
self.fc1 = TensorParallelColumnLinear.load(
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
)
self.fc2 = TensorParallelRowLinear.load(
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
)
def forward(self, hidden_states) -> torch.Tensor:
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
hidden_states = hidden_states.view(-1, self.hidden_size)
hidden_states = self.fc1(hidden_states)
hidden_states = F.gelu(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class Qwen2VisionModel(nn.Module):
def __init__(self, *, prefix, config, weights):
super().__init__()
self.spatial_merge_size = config.spatial_merge_size
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
self.patch_embedding = nn.Conv3d(
in_channels=config.in_chans,
out_channels=config.embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
)
self.patch_embedding.weight = nn.Parameter(
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
)
head_dim = config.embed_dim // config.num_heads
# TODO: replace with static positional embeddings once implemented
theta = 10000.0
dim = head_dim // 2
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.blocks = nn.ModuleList(
[
Qwen2VLVisionBlock(
prefix=f"{prefix}.blocks.{i}",
config=config,
weights=weights,
)
for i in range(config.depth)
]
)
self.merger = Qwen2VLPatchMerger(
prefix=f"{prefix}.merger",
config=config,
weights=weights,
)
self.temporal_patch_size = config.temporal_patch_size
self.spatial_patch_size = config.spatial_patch_size
self.in_channels = config.in_channels
self.embed_dim = config.embed_dim
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
batch_size, _, hidden_size = hidden_state.shape
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
return hidden_state
def forward(
self,
pixel_values: torch.Tensor,
grid_thw: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
# reshape the input tensor for processing
shape = (
-1,
self.in_channels,
self.temporal_patch_size,
self.spatial_patch_size,
self.spatial_patch_size,
)
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
# TODO: revisit to see if we can avoid some of these reshapes
# find the position ids for the input tensor based on the grid_thw
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
# apply the positional embeddings to the position ids
seq = torch.arange(
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
)
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
# create a cu_seqlens tensor to be used in the attention mask
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
# iteratively apply the blocks to the hidden states
for block in self.blocks:
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
# apply the final patch merger to the hidden states
hidden_states = self.merger(hidden_states)
return hidden_states
class Qwen2VLForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.config = config
config.vision_config.quantize = None
config.vision_config.speculator = config.speculator
# set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
# returns rope_scaling.type == "default" for Qwen2-VL model at the moment
if (
hasattr(config, "rope_scaling")
and config.rope_scaling is not None
and config.rope_scaling.get("type", None) == "default"
):
config.rope_scaling.update({"rope_type": "mrope"})
self.hidden_size = config.hidden_size
self.vision_start_token_id = config.vision_start_token_id
self.vision_end_token_id = config.vision_end_token_id
self.image_token_id = config.image_token_id
self.video_token_id = config.video_token_id
self.spatial_merge_size = config.vision_config.spatial_merge_size
self.embed_tokens = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.visual = Qwen2VisionModel(
prefix="visual", config=config.vision_config, weights=weights
)
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
suffix = "lm_head"
self.lm_head = SpeculativeHead.load(
config,
prefix=suffix if not prefix else f"{prefix}.{suffix}",
weights=weights,
)
self.norm = FastRMSNorm.load(
prefix="model.norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.device = weights.device
# based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
# modified to first find segments then initialize position ids for each segment
# Steps:
# locate all vision and text segments
# calculate `vision_segment_lengths` for each vision segment to be used as an offset
# calculate `text_segment_lengths` for each text segment to be used as offset
# create position ids for each vision segment based on the image grid
# create position ids for each text segment
# combine all the position ids
# the final segment is the difference between the last vision segment and the end of the input
# combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
def get_position_ids(
self,
input_ids: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if image_grid_thw is None:
return (
torch.arange(input_ids.shape[0], device=input_ids.device)
.unsqueeze(1)
.repeat(1, 3)
)
spatial_merge_size = self.spatial_merge_size
vision_start_token_id = self.vision_start_token_id
vision_end_token_id = self.vision_end_token_id
device = input_ids.device
dtype = input_ids.dtype
input_ids_len = input_ids.shape[0]
vision_starts = torch.where(input_ids == vision_start_token_id)[0]
vision_ends = torch.where(input_ids == vision_end_token_id)[0]
vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
prev_vision_end = torch.cat(
[torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
)
text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
vision_widths_max = torch.cat(
[
torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
image_grid_thw[:-1, 2] // spatial_merge_size,
]
)
vision_segment_lengths = vision_widths_max + text_lengths_between_vision
vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
# create position ids for each vision segment based on the image grid
llm_pos_ids_list = []
for i, _ in enumerate(vision_segments):
t, h, w = (
image_grid_thw[i][0],
image_grid_thw[i][1] // spatial_merge_size,
image_grid_thw[i][2] // spatial_merge_size,
)
t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
w_indices = torch.arange(w, device=device).repeat(t * h)
image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
# offset by the position of the last vision segment
im = image_position_ids + vision_segment_lengths[i]
llm_pos_ids_list.append(im)
# create position ids for each text segment
text_ranges = [
torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
+ text_segment_lengths[i]
for i, seq_len in enumerate(text_lengths_between_vision)
]
full_llm_pos_ids_list = [
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
]
max_s = full_llm_pos_ids_list[-1].max() + 1
final_text_len = input_ids_len - vision_ends[-1]
if final_text_len > 0:
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
full_llm_pos_ids_list.append(m + max_s)
position_ids = (
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
)
return position_ids
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
cu_seqlen_prefill: Optional[torch.Tensor],
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
slots: torch.Tensor,
seqlen: Seqlen,
hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
lm_head_indices: Optional[torch.Tensor],
pixel_values: torch.FloatTensor = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
pixel_attention_mask=None,
image_sizes: Optional[torch.LongTensor] = None,
adapter_data: Optional[torch.Tensor] = None,
cross_attention_states: Optional[torch.Tensor] = None,
image_indices=None,
):
inputs_embeds = self.embed_tokens(input_ids)
# apply the visual model to the pixel values if they are provided
if pixel_values is not None and len(pixel_values) > 0:
if pixel_values is not None:
image_embeds = self.visual(
pixel_values, grid_thw=image_grid_thw
).squeeze(0)
mask = torch.where(input_ids == self.image_token_id)
inputs_embeds[mask] = image_embeds
hidden_states = self.text_model(
inputs_embeds=inputs_embeds,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=seqlen,
hpu_attention_meta=hpu_attention_meta,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits, speculative_logits = self.lm_head(hidden_states)
return logits, speculative_logits

View File

@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
FlashLlamaForCausalLM,
)
-return FlashLlamaForCausalLM(prefix, config, weights)
+return FlashLlamaForCausalLM(prefix, config, weights, name=name)
elif config.model_type == "mistral":
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
FlashMistralForCausalLM,
@ -16,7 +16,13 @@ def load_text_model(prefix, config, weights, name=None):
FlashGemmaForCausalLM,
)
-return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
+return FlashGemmaForCausalLM(prefix, config, weights)
+elif config.model_type == "gemma2":
+from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
+FlashGemma2ForCausalLM,
+)
+return FlashGemma2ForCausalLM(prefix, config, weights)
elif config.model_type == "paligemma":
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
FlashGemmaForCausalLM,

View File

@ -0,0 +1,489 @@
import torch
from PIL import Image
from io import BytesIO
from opentelemetry import trace
from typing import Iterable, Optional, Tuple, List, Type, Dict
from transformers import PreTrainedTokenizerBase
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
FlashCausalLMBatch,
FlashCausalLM,
)
from text_generation_server.models.globals import PREFIX_CACHING
from loguru import logger
from text_generation_server.utils.log import log_master
from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch
tracer = trace.get_tracer(__name__)
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
IDEFICS2_IMAGE_TOKEN = "<image>"
IDEFICS3_IMAGE_TOKEN = "<image>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
def _prompt_split_image(
*,
image_seq_len: int,
image_rows: int,
image_cols: int,
fake_token_around_image: str,
image_token: str,
global_img_token: str,
):
"""Prompt with expanded image tokens for when the image is split into patches."""
text_split_images = ""
for n_h in range(image_rows):
for n_w in range(image_cols):
text_split_images += (
f"{fake_token_around_image}"
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
+ f"{image_token}" * image_seq_len
)
text_split_images += "\n"
text_split_images += (
f"\n{fake_token_around_image}"
+ f"{global_img_token}"
+ f"{image_token}" * image_seq_len
+ f"{fake_token_around_image}"
)
return text_split_images
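# For example (illustrative values), image_rows=1, image_cols=2, image_seq_len=2 with the
# idefics3 tokens used below expands to:
# "<fake_token_around_image><row_1_col_1><image><image>"
# "<fake_token_around_image><row_1_col_2><image><image>\n"
# "\n<fake_token_around_image><global-img><image><image><fake_token_around_image>"
# i.e. one tagged tile per grid cell followed by a global-image block.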
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (`tuple`):
The size of the input image in the format (height, width).
grid_pinpoints (`List`):
A list containing possible resolutions. Each item in the list should be a tuple or list
of the form `(height, width)`.
patch_size (`int`):
The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if not isinstance(grid_pinpoints, list):
raise ValueError("grid_pinpoints should be a list of tuples or lists")
height, width = select_best_resolution(image_size, grid_pinpoints)
return height // patch_size, width // patch_size
def image_text_replacement(processor, image_input, config, image_id: int) -> str:
if config.model_type == "idefics2":
image_seq_len = 64
image_str = f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_IMAGE_TOKEN * image_seq_len}{IDEFICS2_FAKE_TOKEN}"
if processor.image_processor.do_image_splitting:
image_str *= 5
return image_str
if config.model_type == "idefics3":
# TODO: implement this in a more general way
n_rows = image_input["rows"][0][image_id]
n_cols = image_input["cols"][0][image_id]
image_seq_len = int(
((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
/ (config.scale_factor**2)
)
image_str = _prompt_split_image(
image_seq_len=image_seq_len,
image_rows=n_rows,
image_cols=n_cols,
fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
image_token=IDEFICS3_IMAGE_TOKEN,
global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
return image_str
elif config.model_type == "llava_next":
height, width = image_input["image_sizes"][image_id]
num_features = get_number_of_features(height, width, config)
log_master(
logger.info,
f"Found {num_features} features in image of resolution {height}x{width}",
)
return "<image>" * num_features
elif config.model_type == "paligemma":
return "<image>" * config.text_config.num_image_tokens
elif config.model_type == "qwen2_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
num_pads = grid_t * grid_h * grid_w // 4
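# grid_thw counts pre-merge vision patches; the vision tower merges them 2x2 spatially
# (spatial_merge_size**2 == 4 with the default qwen2_vl config), so the language model
# only sees t * h * w // 4 image placeholder tokens. The qwen2_5_vl branch below follows
# the same rule.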
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
elif config.model_type == "qwen2_5_vl":
grid_t, grid_h, grid_w = image_input["image_grid_thw"][image_id]
num_pads = grid_t * grid_h * grid_w // 4
padding = "<|image_pad|>" * num_pads
return f"<|vision_start|>{padding}<|vision_end|>"
elif config.model_type == "gemma3":
# TODO: get correct number of features via reviewing the Gemma3 architecture
# and calculating the number of image tokens
num_pads = 256
padding = "<image_soft_token>" * num_pads
return f"\n\n<start_of_image>{padding}<end_of_image>\n\n"
else:
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
def image_text_replacement_fixup(config, text: str) -> str:
if config.model_type == "idefics2":
return text.replace(
f"{IDEFICS2_FAKE_TOKEN}{IDEFICS2_FAKE_TOKEN}", IDEFICS2_FAKE_TOKEN
)
return text
def get_unpadded_features(
original_height: int,
original_width: int,
npatches: int,
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
aspect_ratio: float = original_width / original_height
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (original_height * current_width) // original_width
padding = (current_height - new_height) // 2
current_height = current_height - (2 * padding)
else:
new_width = (original_width * current_height) // original_height
padding = (current_width - new_width) // 2
current_width = current_width - (2 * padding)
unpadded_features = current_height * current_width
newline_features = current_height
return (unpadded_features, newline_features)
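# Worked example (illustrative numbers): original_height=336, original_width=672,
# npatches=24, num_patch_height=1, num_patch_width=2 -> current grid is 24x48, both aspect
# ratios equal 2.0, no padding is cropped, so unpadded_features = 24 * 48 = 1152 and
# newline_features = 24.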
def get_number_of_features(height: int, width: int, config) -> int:
# From config
# Hardcoded for CLIP for now
# image_grid_pinpoints = [[336, 672], [672, 336], [672, 672], [1008, 336], [336, 1008]]
image_grid_pinpoints = config.image_grid_pinpoints
image_size = config.vision_config.image_size
patch_size = config.vision_config.patch_size
assert image_size % patch_size == 0
npatches = image_size // patch_size
# Dimensions are intentionally swapped to be bug-compatible with
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
[height, width],
image_grid_pinpoints,
image_size,
)
unpadded_features, newline_features = get_unpadded_features(
height, width, npatches, num_patch_height, num_patch_width
)
# The base patch covers the entire image
base_features = npatches**2
return unpadded_features + newline_features + base_features
class FlashVlmCausalLMBatch(FlashCausalLMBatch):
pixel_values: Optional[List[torch.Tensor]]
pixel_attention_mask: Optional[List[torch.Tensor]]
image_sizes: Optional[List[Tuple[int, int]]]
image_grid_thw: Optional[torch.Tensor]
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches):
batch = super(FlashVlmCausalLMBatch, cls).concatenate(batches)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@tracer.start_as_current_span("filter")
def filter(self, request_ids: List[int]):
batch = super().filter(request_ids)
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
@classmethod
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final
# sizes to insert the correct number of image tokens.
images = []
for r in requests:
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
pass
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
# qwen2_vl expects images larger than 20 pixels per side; this only matters for warmup,
# since the default warmup image is 20x20
if config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if image.width <= 20:
w = image.width * 2
h = image.height * 2
image = image.resize((w, h))
if config.model_type == "llava_next":
images.append(image)
elif config.model_type == "gemma3":
images.append(image)
else:
images.append([image])
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
if images:
kwargs = {}
if (
hasattr(processor, "image_processor_class")
and processor.image_processor_class == "Idefics3ImageProcessor"
):
kwargs["return_row_col_info"] = True
image_inputs = processor.image_processor(
images, return_tensors="pt", **kwargs
)
else:
image_inputs = None
batch_tokenized_inputs = []
max_length = 0
image_id = 0
for r in requests:
full_text = ""
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += chunk.text
elif chunk_type == "image":
full_text += image_text_replacement(
processor, image_inputs, config, image_id
)
image_id += 1
full_text = image_text_replacement_fixup(config, full_text)
input_ids = tokenizer(
full_text,
truncation=True,
max_length=r.truncate,
add_special_tokens=r.add_special_tokens,
)["input_ids"]
max_length = max(max_length, len(input_ids))
batch_tokenized_inputs.append(input_ids)
return batch_tokenized_inputs, image_inputs
@classmethod
def from_pb_processor(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor,
config,
dtype: torch.dtype,
device: torch.device,
) -> "FlashVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config
)
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
if image_inputs is not None:
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
if "pixel_attention_mask" in image_inputs:
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
device=device
)
else:
batch.pixel_attention_mask = None
if "image_sizes" in image_inputs:
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
else:
batch.image_sizes = None
if "image_grid_thw" in image_inputs:
batch.image_grid_thw = image_inputs["image_grid_thw"].to(device=device)
else:
batch.image_grid_thw = None
else:
batch.pixel_values = None
batch.pixel_attention_mask = None
batch.image_sizes = None
batch.image_grid_thw = None
return batch
class FlashVlmCausalLM(FlashCausalLM):
def __init__(
self,
model_id: str,
*,
processor_class=AutoProcessor,
processor_kwargs=None,
batch_class=FlashVlmCausalLMBatch,
revision,
trust_remote_code: bool,
**kwargs,
):
if PREFIX_CACHING:
raise NotImplementedError("Vlm do not work with prefix caching yet")
if processor_kwargs is None:
processor_kwargs = {}
self.processor = processor_class.from_pretrained(
model_id,
revision=revision,
trust_remote_code=trust_remote_code,
**processor_kwargs,
)
self.batch_class = batch_class
super().__init__(
model_id=model_id,
revision=revision,
trust_remote_code=trust_remote_code,
# FIXME: VLM do not work with context chunking yet
support_chunking=False,
**kwargs,
)
@property
def batch_type(self) -> Type[FlashVlmCausalLMBatch]:
return self.batch_class
def max_past(self) -> Optional[int]:
return getattr(self.model.text_model, "max_past", None)
def forward(
self,
batch: FlashVlmCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward
if batch.speculative_ids is not None:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids
B, speculative_length = speculative_ids.shape
new_length = speculative_length + 1
new_input_ids = torch.cat(
[input_ids.unsqueeze(-1), speculative_ids], dim=1
).reshape(-1)
arange = torch.arange(new_length, device=position_ids.device).unsqueeze(0)
arange_int = arange.to(dtype=torch.int32)
new_position_ids = (
position_ids.unsqueeze(-1).expand(B, new_length) + arange
).view(-1)
slots = (slots.unsqueeze(-1).expand(B, new_length) + arange_int).view(-1)
input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1)
cache_lengths_tensor = (
batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1)
# Add Copy the block tables for all members
block_tables = (
block_tables.unsqueeze(1)
.expand(B, new_length, -1)
.reshape(B * new_length, -1)
.contiguous()
)
max_s = max_s + speculative_length
input_ids = new_input_ids
position_ids = new_position_ids
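# i.e. every request is expanded to (1 + speculative_length) rows so the speculated tokens
# can be scored in one forward pass; slots and input_lengths are shifted per extra token,
# mirroring what (as I understand it) the non-VLM flash path does for speculation.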
else:
input_ids = batch.input_ids
position_ids = batch.position_ids
cu_seqlen_prefill = batch.cu_seqlen_prefill
kv_cache = self.kv_cache
block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor
cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices
if self.model.config.model_type in {"qwen2_vl", "qwen2_5_vl"}:
if position_ids.dim() == 1 and batch.prefilling:
position_ids = self.model.get_position_ids(
input_ids, batch.image_grid_thw
)
batch.position_ids = position_ids
if cu_seqlen_prefill is None and self.max_past() is not None:
# In decode, not prefill, we're actually overwriting the KV-cache
# in a circular buffer mode.
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)
kwargs = {}
if htorch.utils.internal.is_lazy():
kwargs["bypass_hpu_graphs"] = False
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
)
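# Scattering the real slot values into a zero tensor shaped like input_ids keeps the slots
# tensor at the padded prefill length; presumably this is what avoids the dynamic shapes
# that would otherwise break HPU graph replay.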
if batch.prefill_cache_indices is not None:
slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=lm_head_indices,
pixel_values=batch.pixel_values,
pixel_attention_mask=batch.pixel_attention_mask,
image_sizes=batch.image_sizes,
image_grid_thw=batch.image_grid_thw,
**kwargs,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
if batch.pixel_attention_mask is not None:
batch.pixel_attention_mask = None
if batch.image_sizes is not None:
batch.image_sizes = None
if batch.image_grid_thw is not None:
batch.image_grid_thw = None
return logits, speculative_logits

View File

@ -1,53 +1,31 @@
-import torch
import os
from typing import Dict, Optional
from loguru import logger
from text_generation_server.utils.log import log_master
+REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
ATTENTION = os.getenv("ATTENTION", "default")
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
"1",
"true",
}
+PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
-_expected = {"paged", "flashdecoding", "flashinfer", "default"}
+_expected = {"paged", "default"}
assert (
ATTENTION in _expected
), f"Attention is not valid {ATTENTION}, expected {_expected}"
log_master(logger.info, f"Using Attention = {ATTENTION}")
-if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
-raise RuntimeError("Prefix caching is only supported with flashinfer")
-MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
assert TGI_WIGGLE_ROOM > 0
assert TGI_WIGGLE_ROOM < 1
# This is overridden by the cli
BLOCK_SIZE: int
-if ATTENTION == "flashdecoding":
-BLOCK_SIZE = 256
-elif ATTENTION == "flashinfer":
-BLOCK_SIZE = 1
-else:
-BLOCK_SIZE = 16
-# This is overridden by the cli
-cuda_graphs = os.getenv("CUDA_GRAPHS")
-if cuda_graphs is not None:
-try:
-cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
-except Exception as e:
-raise RuntimeError(
-f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
-)
-else:
-cuda_graphs = None
-CUDA_GRAPHS = cuda_graphs
+BLOCK_SIZE = 128
# This is overridden at model loading.
global MODEL_ID

View File

@ -34,9 +34,6 @@ from text_generation_server.utils import (
)
from text_generation_server.utils.quantization import get_loader
-from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
@ -596,22 +593,8 @@ class IdeficsCausalLM(Model):
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
-if torch.cuda.is_available():
-device = torch.device(f"cuda:{rank}")
-# 9b seems to work correctly enough in float16, but 80b seems
-# to be really saturating for f16.
-dtype = torch.float16 if dtype is None else dtype
-elif SYSTEM == "ipex":
-if hasattr(torch, "xpu") and torch.xpu.is_available():
-device = torch.device(f"xpu:{rank}")
-dtype = torch.float16 if dtype is None else dtype
-else:
-device = torch.device("cpu")
-# Float16 doesn't exist on target.
-dtype = torch.bfloat16 if dtype is None else dtype
-else:
-device = torch.device("cpu")
-dtype = torch.float32 if dtype is None else dtype
+device = torch.device("hpu")
+dtype = torch.bfloat16 if dtype is None else dtype
self.device, self.dtype = device, dtype
config = AutoConfig.from_pretrained(

View File

@ -1,28 +1,30 @@
from io import BytesIO
from PIL import Image
import torch import torch
import numpy as np
from typing import Iterable, Optional, Tuple, List, Dict from typing import Iterable, Optional, Tuple, List, Dict
from text_generation_server.pb.generate_pb2 import Request from text_generation_server.pb.generate_pb2 import Request
from io import BytesIO
from PIL import Image
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import ( from transformers import (
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
) )
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
block_tables_to_ragged,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
FlashVlmCausalLM,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.layers.attention import Seqlen, trim_seqlen_metadata
import habana_frameworks.torch as htorch
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@dataclass @dataclass
class MllamaCausalLMBatch(VlmCausalLMBatch): class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
image_indices: List[int] = 42 image_indices: List[int] = 42
aspect_ratio_ids: Optional[torch.Tensor] = None aspect_ratio_ids: Optional[torch.Tensor] = None
aspect_ratio_mask: Optional[torch.Tensor] = None aspect_ratio_mask: Optional[torch.Tensor] = None
@ -158,7 +160,7 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
config, config,
dtype: torch.dtype, dtype: torch.dtype,
device: torch.device, device: torch.device,
) -> "VlmCausalLMBatch": ) -> "FlashVlmCausalLMBatch":
batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs( batch_tokenized_inputs, image_inputs = cls.batch_tokenized_inputs(
pb.requests, tokenizer, processor, config pb.requests, tokenizer, processor, config
) )
@ -167,6 +169,13 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp( batch.all_input_ids_tensor = batch.all_input_ids_tensor.clamp(
max=config.text_config.vocab_size - 1 max=config.text_config.vocab_size - 1
) )
if isinstance(batch.input_ids, list):
if len(batch) > 1:
input_ids = np.concatenate(batch.input_ids, dtype=np.int64)
else:
input_ids = batch.input_ids[0]
batch.input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1) batch.input_ids = batch.input_ids.clamp(max=config.text_config.vocab_size - 1)
if image_inputs is not None: if image_inputs is not None:
@ -187,10 +196,10 @@ class MllamaCausalLMBatch(VlmCausalLMBatch):
return batch return batch
class MllamaCausalLM(VlmCausalLM): class FlashMllamaCausalLM(FlashVlmCausalLM):
def forward( def forward(
self, self,
batch: VlmCausalLMBatch, batch: FlashMllamaCausalLMBatch,
adapter_data: Optional[Dict[str, torch.Tensor]] = None, adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
# Model Forward # Model Forward
@ -202,7 +211,7 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
speculative_ids = batch.speculative_ids speculative_ids = batch.speculative_ids
@ -221,8 +230,8 @@ class MllamaCausalLM(VlmCausalLM):
input_lengths = ( input_lengths = (
input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int input_lengths.unsqueeze(-1).expand(B, new_length) + arange_int
).view(-1) ).view(-1)
prefix_lens_tensor = ( cache_lengths_tensor = (
batch.prefix_lens_tensor.unsqueeze(-1).expand(B, new_length) batch.cache_lengths_tensor.unsqueeze(-1).expand(B, new_length)
).reshape(-1) ).reshape(-1)
# Add Copy the block tables for all members # Add Copy the block tables for all members
@ -244,8 +253,8 @@ class MllamaCausalLM(VlmCausalLM):
block_tables = batch.block_tables_tensor block_tables = batch.block_tables_tensor
slots = batch.slots[batch.slot_indices] slots = batch.slots[batch.slot_indices]
input_lengths = batch.input_lengths_tensor input_lengths = batch.input_lengths_tensor
prefix_lens_tensor = batch.prefix_lens_tensor cache_lengths_tensor = batch.cache_lengths_tensor
max_s = batch.max_seqlen max_s = batch.max_current_length
lm_head_indices = batch.prefill_head_indices lm_head_indices = batch.prefill_head_indices
if cu_seqlen_prefill is None and self.max_past() is not None: if cu_seqlen_prefill is None and self.max_past() is not None:
@ -254,104 +263,46 @@ class MllamaCausalLM(VlmCausalLM):
# This makes sure the max_s for the decode pass is correct. # This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s) max_s = min(self.max_past(), max_s)
bs = input_ids.shape[0] seqlen = Seqlen(
# Try to find an associated cuda graph input_lengths=input_lengths,
bs = input_ids.shape[0] cache_lengths=cache_lengths_tensor,
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs]) cu_seqlen_q=cu_seqlen_prefill,
if sorted_padded_bs: )
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if (
cu_seqlen_prefill is not None
or cuda_graph is None
# Only run cuda graphs when there's no images.
or batch.cross_attention_states is not None
):
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
)
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
if batch.pixel_values is not None: if batch.pixel_values is not None:
cross_attention_states = self.model.vision_forward( cross_attention_states = self.model.vision_forward(
pixel_values=batch.pixel_values, pixel_values=batch.pixel_values,
aspect_ratio_ids=batch.aspect_ratio_ids, aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask, aspect_ratio_mask=batch.aspect_ratio_mask,
)
batch.cross_attention_states = cross_attention_states
cross_attention_states = batch.cross_attention_states
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
adapter_data=adapter_data,
image_indices=batch.image_indices[:],
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
return logits, speculative_logits
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
) )
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables batch.cross_attention_states = cross_attention_states
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)
# Replay the graph cross_attention_states = batch.cross_attention_states
cuda_graph["graph"].replay()
# Slice output to the correct shape kwargs = {}
speculative_logits = ( if htorch.utils.internal.is_lazy():
cuda_graph["speculative_logits"][:bs] kwargs["bypass_hpu_graphs"] = False
if cuda_graph["speculative_logits"] is not None if batch.prefill_cache_indices is not None:
else None slots_pad = torch.zeros_like(input_ids)
slots_pad[batch.prefill_cache_indices] = slots
slots = slots_pad
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
slots=slots,
seqlen=trim_seqlen_metadata(seqlen),
hpu_attention_meta=batch.hpu_attn_meta,
lm_head_indices=lm_head_indices,
cross_attention_states=cross_attention_states,
# TODO list
adapter_data=None,
image_indices=batch.image_indices[:],
**kwargs,
) )
logits = cuda_graph["logits"][:bs] if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
if batch.pixel_values is not None:
batch.pixel_values = None
return logits, speculative_logits return logits, speculative_logits
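Note on the slot padding above: scattering the real slot ids into a zeros_like(input_ids) buffer keeps the slots tensor at a static length, which is what HPU graph capture needs. A minimal standalone sketch with made-up values (not a drop-in excerpt of the PR):

    import torch

    # Illustrative values only.
    input_ids = torch.zeros(8, dtype=torch.long)         # padded prefill of 8 positions
    slots = torch.tensor([100, 101, 102])                # KV-cache slots of the real tokens
    prefill_cache_indices = torch.tensor([0, 1, 2])      # where the real tokens live

    slots_pad = torch.zeros_like(input_ids)              # same static shape as input_ids
    slots_pad[prefill_cache_indices] = slots             # scatter real slots, padding stays 0
    slots = slots_pad                                    # shape now matches input_ids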

View File

@ -33,6 +33,7 @@ class Model(ABC):
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
speculate: Optional[int] = None, speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID, adapter_id: str = BASE_MODEL_ADAPTER_ID,
support_chunking: bool = False,
): ):
self.model_id = model_id self.model_id = model_id
self.model = model.eval() self.model = model.eval()
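A minimal sketch of how a backend model might opt into the new flag; the subclass name is hypothetical and the remaining constructor arguments are forwarded through **kwargs for brevity:

    from text_generation_server.models.model import Model  # usual location of the base class

    class MyHpuModel(Model):
        def __init__(self, model_id, model, tokenizer, **kwargs):
            # Opt in to chunked prefill; the flag defaults to False for models
            # that cannot split their prefill.
            super().__init__(
                model_id=model_id,
                model=model,
                tokenizer=tokenizer,
                support_chunking=True,
                **kwargs,
            )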

View File

@ -4,8 +4,8 @@ import torch
import torch.distributed import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from typing import Iterable from typing import Iterable
from text_generation_server.models.vlm_causal_lm import ( from text_generation_server.models.flash_vlm_causal_lm import (
VlmCausalLMBatch, FlashVlmCausalLMBatch,
image_text_replacement, image_text_replacement,
) )
@ -14,7 +14,7 @@ from text_generation_server.pb.generate_pb2 import Request
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
class PaliGemmaBatch(VlmCausalLMBatch): class PaliGemmaBatch(FlashVlmCausalLMBatch):
@classmethod @classmethod
def batch_tokenized_inputs( def batch_tokenized_inputs(
cls, requests: Iterable[Request], tokenizer, processor, config cls, requests: Iterable[Request], tokenizer, processor, config

View File

@ -10,7 +10,6 @@ from transformers import (
AutoConfig, AutoConfig,
) )
from typing import Optional, Tuple, List, Type, Dict from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
@ -555,20 +554,9 @@ class Seq2SeqLM(Model):
): ):
self.quantize = quantize self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed() self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device("hpu")
dtype = default_dtype if dtype is None else dtype dtype = torch.bfloat16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
config = config_class.from_pretrained( config = config_class.from_pretrained(
model_id, model_id,
@ -600,7 +588,7 @@ class Seq2SeqLM(Model):
aliases=aliases, aliases=aliases,
weights_loader=weights_loader, weights_loader=weights_loader,
) )
if config.quantize in ["awq", "exl2", "gptq", "marlin"]: if config.quantize in ["awq", "gptq"]:
weights._set_gptq_params(model_id, revision) weights._set_gptq_params(model_id, revision)
model = model_class(config, weights) model = model_class(config, weights)

View File

@ -69,11 +69,7 @@ MAX_TOTAL_TOKENS = int(os.environ.get("MAX_TOTAL_TOKENS", 8192))
PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128)) PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 128))
CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1)) LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
if max_batch_size_str is not None:
MAX_BATCH_SIZE = int(max_batch_size_str)
else:
raise ValueError("MAX_BATCH_SIZE is not set")
PREFILL_WARMUP_BATCH_SIZE_LIST = [] PREFILL_WARMUP_BATCH_SIZE_LIST = []
PREFILL_WARMUP_SEQLEN_LIST = [] PREFILL_WARMUP_SEQLEN_LIST = []
@ -1467,6 +1463,12 @@ class VlmCausalLM(Model):
batch = self.batch_from_pb(request.batch, is_warmup=True) batch = self.batch_from_pb(request.batch, is_warmup=True)
max_input_tokens = request.max_input_tokens max_input_tokens = request.max_input_tokens
max_prefill_batch_size = batch.input_ids.shape[0] max_prefill_batch_size = batch.input_ids.shape[0]
max_batch_size_str = os.environ.get("MAX_BATCH_SIZE")
if max_batch_size_str is not None:
MAX_BATCH_SIZE = int(max_batch_size_str)
else:
raise ValueError("MAX_BATCH_SIZE is not set")
try: try:
# max prefill batch size warmup # max prefill batch size warmup
_, prefill_batch, _ = self.generate_token([batch], is_warmup=True) _, prefill_batch, _ = self.generate_token([batch], is_warmup=True)

View File

@ -18,22 +18,27 @@ from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model_with_lora_adapters from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.globals import set_model_id from text_generation_server.models.globals import set_model_id, ATTENTION
from text_generation_server.models.globals import set_adapter_to_index from text_generation_server.models.globals import set_adapter_to_index
from text_generation_server.utils.adapter import AdapterInfo from text_generation_server.utils.adapter import AdapterInfo
from text_generation_server.utils.tokens import make_tokenizer_optional from text_generation_server.utils.tokens import make_tokenizer_optional
from text_generation_server.utils.prefill_chunking import set_max_prefill_tokens
try: try:
from text_generation_server.models.pali_gemma import PaliGemmaBatch from text_generation_server.models.pali_gemma import PaliGemmaBatch
from text_generation_server.models.mllama_causal_lm import FlashMllamaCausalLMBatch
from text_generation_server.models.vlm_causal_lm import ( from text_generation_server.models.vlm_causal_lm import (
VlmCausalLMBatch, VlmCausalLMBatch,
) )
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch from text_generation_server.models.flash_vlm_causal_lm import (
FlashVlmCausalLMBatch,
)
VLM_BATCH_TYPES = { VLM_BATCH_TYPES = {
PaliGemmaBatch, PaliGemmaBatch,
VlmCausalLMBatch, VlmCausalLMBatch,
IdeficsCausalLMBatch, FlashVlmCausalLMBatch,
FlashMllamaCausalLMBatch,
} }
except (ImportError, NotImplementedError): except (ImportError, NotImplementedError):
# These imports can fail on CPU/Non flash. # These imports can fail on CPU/Non flash.
@ -103,14 +108,50 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context): async def Warmup(self, request, context):
max_supported_total_tokens, max_input_tokens, max_total_tokens = ( if ATTENTION == "paged":
self.model.warmup(request) set_max_prefill_tokens(request.max_prefill_tokens)
) if (
self.model.batch_type in VLM_BATCH_TYPES
): # Hack, I would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb_processor(
request.batch,
self.model.tokenizer,
self.model.processor,
self.model.model.config,
self.model.dtype,
self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
request.batch,
self.model.tokenizer,
self.model.dtype,
self.model.device,
)
# W/A for the skip tokenizer path # Override default values with None for clearer semantics.
# We need to call make_tokenizer_optional after the warmup, max_input_tokens = (
# because router is not aware of that feature request.max_input_tokens
make_tokenizer_optional(self.model.tokenizer) if request.HasField("max_input_tokens")
else None
)
max_total_tokens = (
request.max_total_tokens
if request.HasField("max_total_tokens")
else None
)
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
self.model.warmup(batch, max_input_tokens, max_total_tokens)
)
else:
max_supported_total_tokens, max_input_tokens, max_total_tokens = (
self.model.warmup(request)
)
# W/A for the skip tokenizer path
# We need to call make_tokenizer_optional after the warmup,
# because router is not aware of that feature
make_tokenizer_optional(self.model.tokenizer)
return generate_pb2.WarmupResponse( return generate_pb2.WarmupResponse(
max_supported_total_tokens=max_supported_total_tokens, max_supported_total_tokens=max_supported_total_tokens,
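The two HasField checks above implement "unset means None"; the same pattern can be factored into a small helper (the helper name is hypothetical, not part of this PR):

    def _optional_field(request, name: str):
        # protobuf optional fields report presence via HasField; fall back to
        # None so the model can apply its own defaults.
        return getattr(request, name) if request.HasField(name) else None

    # e.g. inside Warmup:
    # max_input_tokens = _optional_field(request, "max_input_tokens")
    # max_total_tokens = _optional_field(request, "max_total_tokens")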

View File

@ -1,15 +1,13 @@
import os import os
import torch import torch
from torch.distributed import ProcessGroup
from datetime import timedelta from datetime import timedelta
from loguru import logger from loguru import logger
# Tensor Parallelism settings # Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0")) RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))
# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
class FakeBarrier: class FakeBarrier:
@ -17,10 +15,11 @@ class FakeBarrier:
pass pass
class FakeGroup: class FakeGroup(ProcessGroup):
def __init__(self, rank, size): def __init__(self, rank, size):
self._rank = rank self._rank = rank
self._size = size self._size = size
super().__init__(rank, size)
def allreduce(self, *args, **kwargs): def allreduce(self, *args, **kwargs):
return FakeBarrier() return FakeBarrier()
@ -42,42 +41,11 @@ class FakeGroup:
def rank(self): def rank(self):
return self._rank return self._rank
def _get_backend_name(self):
return "fake"
def initialize_torch_distributed(): def initialize_torch_distributed():
world_size = int(os.getenv("WORLD_SIZE", "1"))
options = None
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL
# Set the device id.
assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
device = RANK % torch.cuda.device_count()
torch.cuda.set_device(device)
torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
elif torch.hpu.is_available():
backend = "hccl"
n_hpus = torch.hpu.device_count()
if world_size > n_hpus:
raise ValueError(
f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})."
)
else:
try:
import oneccl_bindings_for_pytorch # noqa: F401
backend = "ccl"
if os.getenv("CCL_WORKER_COUNT", None) is None:
os.environ["CCL_WORKER_COUNT"] = str(1)
except ImportError:
backend = "gloo"
options = None
if WORLD_SIZE == 1: if WORLD_SIZE == 1:
return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
else: else:
@ -87,11 +55,10 @@ def initialize_torch_distributed():
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
# Call the init process. # Call the init process.
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=backend, backend="hccl",
world_size=WORLD_SIZE, world_size=WORLD_SIZE,
rank=RANK, rank=RANK,
timeout=timedelta(seconds=60), timeout=timedelta(seconds=120),
pg_options=options,
) )
else: else:
logger.warning("torch.distributed is already initialized.") logger.warning("torch.distributed is already initialized.")
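Usage is unchanged by the cleanup; a minimal sketch of what callers get back, assuming RANK/WORLD_SIZE come from the environment as above (import path is the usual utils.dist location):

    import torch
    from text_generation_server.utils.dist import initialize_torch_distributed, FakeGroup

    process_group, rank, world_size = initialize_torch_distributed()

    if world_size == 1:
        # Single card: a FakeGroup that satisfies the ProcessGroup interface
        # without starting torch.distributed.
        assert isinstance(process_group, FakeGroup)
    else:
        # Multi-card: an hccl-backed group with the 120s init timeout above.
        assert torch.distributed.is_initialized()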

View File

@ -1,75 +1,28 @@
import torch import torch
from loguru import logger from loguru import logger
import os
import importlib.util def get_hpu_free_memory(device, memory_fraction):
from habana_frameworks.torch.hpu import memory_stats
def is_ipex_available():
return importlib.util.find_spec("intel_extension_for_pytorch") is not None
def get_cuda_free_memory(device, memory_fraction):
total_free_memory, _ = torch.cuda.mem_get_info(device)
total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
return free_memory
def get_xpu_free_memory(device, memory_fraction):
total_memory = torch.xpu.get_device_properties(device).total_memory
device_id = device.index device_id = device.index
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0")) mem_stats = memory_stats(device_id)
logger.info(f"mem_stats: {mem_stats}")
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
free_memory = max( free_memory = max(
0, 0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
int(
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
),
) )
return free_memory return free_memory
def get_cpu_free_memory(device, memory_fraction): def synchronize_hpu(device):
import psutil torch.hpu.synchronize()
from text_generation_server.utils.dist import WORLD_SIZE
mem = psutil.virtual_memory()
free_memory = int(mem.available * 0.95 / WORLD_SIZE)
return free_memory
def noop(*args, **kwargs): def noop(*args, **kwargs):
pass pass
SYSTEM = None empty_cache = noop
if torch.version.hip is not None: synchronize = synchronize_hpu
SYSTEM = "rocm" get_free_memory = get_hpu_free_memory
empty_cache = torch.cuda.empty_cache
synchronize = torch.cuda.synchronize
get_free_memory = get_cuda_free_memory
elif torch.version.cuda is not None and torch.cuda.is_available():
SYSTEM = "cuda"
empty_cache = torch.cuda.empty_cache
synchronize = torch.cuda.synchronize
get_free_memory = get_cuda_free_memory
elif is_ipex_available():
SYSTEM = "ipex"
import intel_extension_for_pytorch # noqa: F401
if hasattr(torch, "xpu") and torch.xpu.is_available():
empty_cache = torch.xpu.empty_cache
synchronize = torch.xpu.synchronize
get_free_memory = get_xpu_free_memory
else:
empty_cache = noop
synchronize = noop
get_free_memory = get_cpu_free_memory
else:
SYSTEM = "cpu"
empty_cache = noop
synchronize = noop
get_free_memory = get_cpu_free_memory
logger.info(f"Detected system {SYSTEM}")
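A usage sketch for the new helper; the device index and fraction are illustrative, and in the server the fraction comes from HPU_MEMORY_FRACTION (default 0.8):

    import torch
    from text_generation_server.utils.import_utils import get_free_memory

    device = torch.device("hpu", 0)        # explicit index, since the helper reads device.index
    free_bytes = get_free_memory(device, 0.8)
    print(f"usable HPU memory: {free_bytes / 2**30:.1f} GiB")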

View File

@ -0,0 +1,22 @@
import importlib
from loguru import logger
from hf_kernels import load_kernel as hf_load_kernel
from text_generation_server.utils.log import log_once
def load_kernel(*, module: str, repo_id: str):
"""
Load a kernel. First try to load it as the given module (e.g. for
local development), falling back to a locked Hub kernel.
"""
try:
m = importlib.import_module(module)
log_once(logger.info, f"Using local module for `{module}`")
return m
except ModuleNotFoundError:
return hf_load_kernel(repo_id=repo_id)
__all__ = ["load_kernel"]
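Hypothetical usage of load_kernel; the module and repo_id below are placeholders rather than names shipped by this PR, and the import path is assumed since the new file's location is not shown in this hunk:

    from text_generation_server.utils.kernels import load_kernel  # path assumed

    # Prefer a locally installed module during development, otherwise pull the
    # locked kernel from the Hub.
    ops = load_kernel(module="my_custom_ops", repo_id="my-org/my-custom-ops")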

View File

@ -0,0 +1,24 @@
from typing import Optional
SUPPORT_CHUNKING: Optional[bool] = None
MAX_PREFILL_TOKENS: Optional[int] = None
def set_support_chunking(support_chunking: bool):
global SUPPORT_CHUNKING
SUPPORT_CHUNKING = support_chunking
def get_support_chunking() -> bool:
global SUPPORT_CHUNKING
return SUPPORT_CHUNKING
def set_max_prefill_tokens(max_prefill_tokens: int):
global MAX_PREFILL_TOKENS
MAX_PREFILL_TOKENS = max_prefill_tokens
def get_max_prefill_tokens() -> int:
global MAX_PREFILL_TOKENS
return MAX_PREFILL_TOKENS
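Usage sketch for the new prefill-chunking globals; the values are illustrative, the import path matches the one used in server.py above:

    from text_generation_server.utils.prefill_chunking import (
        set_max_prefill_tokens,
        get_max_prefill_tokens,
        set_support_chunking,
        get_support_chunking,
    )

    set_max_prefill_tokens(4096)       # recorded once per Warmup request
    set_support_chunking(False)        # models opt in/out at construction time

    budget = get_max_prefill_tokens()  # consulted later when batching prefills
    if get_support_chunking():
        pass                           # split long prefills into <= budget chunks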

View File

@ -4,9 +4,7 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from text_generation_server.layers.marlin.gptq import can_use_gptq_marlin
from text_generation_server.utils.weights import ( from text_generation_server.utils.weights import (
DefaultWeightsLoader,
WeightsLoader, WeightsLoader,
) )
@ -129,64 +127,13 @@ def get_loader(
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config." f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
) )
if can_use_gptq_marlin( return GPTQWeightsLoader(
bits=quantizer_config.bits, bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize, groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method, quant_method=quantizer_config.quant_method,
quantize=quantize, quantize=quantize,
sym=quantizer_config.sym, sym=quantizer_config.sym,
):
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
return GPTQMarlinWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
else:
return GPTQWeightsLoader(
bits=quantizer_config.bits,
desc_act=quantizer_config.desc_act,
groupsize=quantizer_config.groupsize,
quant_method=quantizer_config.quant_method,
quantize=quantize,
sym=quantizer_config.sym,
)
elif quantize == "bitsandbytes":
from text_generation_server.layers.bnb import BNBWeight
return DefaultWeightsLoader(BNBWeight)
elif quantize == "bitsandbytes-fp4":
from text_generation_server.layers.bnb import BNBFP4Weight
return DefaultWeightsLoader(BNBFP4Weight)
elif quantize == "bitsandbytes-nf4":
from text_generation_server.layers.bnb import BNBNF4Weight
return DefaultWeightsLoader(BNBNF4Weight)
elif quantize == "eetq":
from text_generation_server.layers.eetq import EETQWeight
return DefaultWeightsLoader(EETQWeight)
elif quantize == "exl2":
from text_generation_server.layers.exl2 import Exl2WeightsLoader
return Exl2WeightsLoader()
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeightsLoader
# TODO: improve check once we have one config type per quantize value
if not isinstance(quantizer_config, _QuantizerConfig):
raise ValueError(
f"Quantize is set to `{quantize}` but received a `{quantizer_config.__class__.__name__}` config."
)
return MarlinWeightsLoader(
bits=quantizer_config.bits,
is_marlin_24=quantizer_config.checkpoint_format == "marlin_24",
) )
elif quantize == "fp8" or quantize is None: elif quantize == "fp8" or quantize is None:
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
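A rough sketch of the dispatch that remains after dropping the marlin/bitsandbytes/eetq/exl2 branches, covering only the cases visible in this hunk; other quantize values are handled elsewhere in get_loader and are not shown here:

    def _loader_kind(quantize):
        # Mapping implied by the branches above: awq/gptq always use
        # GPTQWeightsLoader (no marlin fallback on HPU), fp8 and unquantized
        # use the hybrid FP8/unquantized loader.
        if quantize in ("awq", "gptq"):
            return "GPTQWeightsLoader"
        if quantize in ("fp8", None):
            return "HybridFP8UnquantLoader"
        raise ValueError(f"quantize={quantize!r} is not covered by this sketch")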

View File

@ -7,8 +7,6 @@ from typing import Dict, List, Optional, Union, Type
from safetensors import safe_open from safetensors import safe_open
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
class WeightsLoader(ABC): class WeightsLoader(ABC):
""" """
@ -88,12 +86,9 @@ class UnquantizedWeight(Weight):
weight: torch.Tensor weight: torch.Tensor
def get_linear(self, bias: torch.Tensor): def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.linear import FastLinear, FastLinearROCm from text_generation_server.layers.linear import FastLinear
if SYSTEM == "rocm": return FastLinear(self.weight, bias)
return FastLinearROCm(self.weight, bias)
else:
return FastLinear(self.weight, bias)
class DefaultWeightsLoader(WeightsLoader): class DefaultWeightsLoader(WeightsLoader):
@ -197,7 +192,7 @@ class Weights:
slice_ = f.get_slice(tensor_name) slice_ = f.get_slice(tensor_name)
return slice_ return slice_
def _has_tensor(self, tensor_name: str): def has_tensor(self, tensor_name: str):
try: try:
self.get_filename(tensor_name) self.get_filename(tensor_name)
except Exception: except Exception:
@ -207,7 +202,9 @@ class Weights:
def get_shape(self, tensor_name: str): def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape() return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True): def get_tensor(
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
) -> torch.Tensor:
filename, tensor_name = self.get_filename(tensor_name) filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename) f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name) tensor = f.get_tensor(tensor_name)
@ -218,6 +215,7 @@ class Weights:
tensor.dtype tensor.dtype
not in [ not in [
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.int8,
torch.int16, torch.int16,
torch.int32, torch.int32,
torch.int64, torch.int64,
@ -253,7 +251,8 @@ class Weights:
# u4 which are disguised as int32. exl2 uses int16. # u4 which are disguised as int32. exl2 uses int16.
# FP8 uses torch.float8_e4m3fn. # FP8 uses torch.float8_e4m3fn.
if ( if (
tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32) tensor.dtype
not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)
and to_dtype and to_dtype
): ):
tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(dtype=self.dtype)
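The torch.int8 addition matters because int8-packed quantized weights must keep their storage dtype; a small illustrative helper (not part of the PR) showing the cast rule with the skip list written out:

    import torch

    # dtypes that are never cast to the model dtype when loading tensors
    _SKIP_CAST = (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32, torch.int64)

    def maybe_cast(tensor: torch.Tensor, model_dtype: torch.dtype) -> torch.Tensor:
        return tensor if tensor.dtype in _SKIP_CAST else tensor.to(dtype=model_dtype)

    maybe_cast(torch.zeros(2, dtype=torch.int8), torch.bfloat16).dtype     # stays torch.int8
    maybe_cast(torch.zeros(2, dtype=torch.float32), torch.bfloat16).dtype  # torch.bfloat16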
@ -329,6 +328,7 @@ class Weights:
tensor.dtype tensor.dtype
not in [ not in [
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.int8,
torch.int16, torch.int16,
torch.int32, torch.int32,
torch.int64, torch.int64,

View File

@ -2,7 +2,7 @@ use std::sync::Arc;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use crate::radix::RadixAllocator; use crate::radix::RadixAllocator;
use text_generation_router::usage_stats::Env;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct BlockAllocation { pub struct BlockAllocation {
pub allocation_id: u64, pub allocation_id: u64,
@ -141,6 +141,7 @@ pub struct SimpleAllocator {
free_blocks: Vec<u32>, free_blocks: Vec<u32>,
block_size: u32, block_size: u32,
window_size: Option<u32>, window_size: Option<u32>,
is_hpu_device: bool,
} }
impl SimpleAllocator { impl SimpleAllocator {
@ -150,6 +151,7 @@ impl SimpleAllocator {
// Block 0 is reserved for health checks // Block 0 is reserved for health checks
free_blocks: (1..blocks).collect(), free_blocks: (1..blocks).collect(),
window_size, window_size,
is_hpu_device: Env::new().is_hpu_device(),
} }
} }
} }
@ -179,9 +181,15 @@ impl Allocator for SimpleAllocator {
if required_blocks > self.free_blocks.len() as u32 { if required_blocks > self.free_blocks.len() as u32 {
None None
} else { } else {
let blocks = self if self.is_hpu_device {
self.free_blocks.sort_by(|a, b| b.cmp(a));
}
let mut blocks = self
.free_blocks .free_blocks
.split_off(self.free_blocks.len() - required_blocks as usize); .split_off(self.free_blocks.len() - required_blocks as usize);
if self.is_hpu_device {
blocks.sort();
}
let mut slots = let mut slots =
Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize); Vec::with_capacity((required_blocks * self.block_size * repeats as u32) as usize);
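A Python illustration of the new ordering on HPU: sorting the free list in descending order before split_off means the allocation takes the lowest-numbered free blocks, and the final sort hands them back as an ascending block table:

    free_blocks = [7, 2, 9, 4, 5]     # illustrative free list
    required = 3

    free_blocks.sort(reverse=True)    # [9, 7, 5, 4, 2]
    blocks = free_blocks[-required:]  # tail -> [5, 4, 2] (the lowest ids)
    del free_blocks[-required:]       # remaining free list: [9, 7]
    blocks.sort()                     # [2, 4, 5] returned in ascending order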

View File

@ -28,8 +28,8 @@ impl Env {
} }
} }
pub fn is_hpu_device(&self) -> bool { pub fn should_start_a_single_hpu_shard(&self) -> bool {
self.hpu_env != "N/A" self.hpu_env != "N/A" && std::env::var("ATTENTION").as_deref() != Ok("paged")
} }
} }

View File

@ -1559,7 +1559,7 @@ fn spawn_shards(
) -> Result<(), LauncherError> { ) -> Result<(), LauncherError> {
// Start shard processes // Start shard processes
for rank in 0..num_shard { for rank in 0..num_shard {
if rank != 0 && env_runtime::Env::new().is_hpu_device() { if rank != 0 && env_runtime::Env::new().should_start_a_single_hpu_shard() {
tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server"); tracing::info!("Running on HPU, the launcher will not do any sharding as actual sharding is done in the server");
break; break;
} }
@ -1639,7 +1639,7 @@ fn spawn_shards(
if shard_ready == num_shard { if shard_ready == num_shard {
break; break;
} }
if env_runtime::Env::new().is_hpu_device() { if env_runtime::Env::new().should_start_a_single_hpu_shard() {
tracing::info!("HPU detected, shard is ready"); tracing::info!("HPU detected, shard is ready");
break; break;
} }

View File

@ -157,6 +157,7 @@ pub struct Env {
docker_label: &'static str, docker_label: &'static str,
nvidia_info: Option<Vec<NvidiaSmiInfo>>, nvidia_info: Option<Vec<NvidiaSmiInfo>>,
xpu_info: Option<Vec<XpuSmiInfo>>, xpu_info: Option<Vec<XpuSmiInfo>>,
hpu_info: Option<Vec<HpuSmiInfo>>,
system_env: SystemInfo, system_env: SystemInfo,
} }
@ -289,6 +290,60 @@ impl XpuSmiInfo {
} }
} }
#[derive(Debug, Serialize, Clone)]
struct HpuSmiInfo {
name: String,
pci_bus_id: String,
driver_version: String,
temperature: String,
utilization: String,
memory_total: String,
memory_free: String,
memory_used: String,
power_draw_instant: String,
}
impl HpuSmiInfo {
fn new() -> Option<Vec<HpuSmiInfo>> {
let output = Command::new("hl-smi")
.args([
"--query-aip=name,bus_id,driver_version,temperature.aip,utilization.aip,memory.total,memory.free,memory.used,power.draw",
"--format=csv"
])
.output()
.ok()?;
if !output.status.success() {
return None;
}
let stdout = String::from_utf8(output.stdout).ok()?;
let mut rdr = ReaderBuilder::new()
.has_headers(true)
.from_reader(stdout.as_bytes());
let mut infos = Vec::new();
for result in rdr.records() {
let record = result.ok()?;
infos.push(HpuSmiInfo {
name: record[0].to_string(),
pci_bus_id: record[1].to_string(),
driver_version: record[2].to_string(),
temperature: record[3].to_string(),
utilization: record[4].to_string(),
memory_total: record[5].to_string(),
memory_free: record[6].to_string(),
memory_used: record[7].to_string(),
power_draw_instant: record[8].to_string(),
});
}
Some(infos)
}
}
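For reference, the same hl-smi query expressed as a short Python sketch; it assumes hl-smi is on PATH, and the CSV columns follow the --query-aip field order shown above:

    import csv
    import subprocess

    def hpu_smi_info():
        query = ("name,bus_id,driver_version,temperature.aip,utilization.aip,"
                 "memory.total,memory.free,memory.used,power.draw")
        proc = subprocess.run(
            ["hl-smi", f"--query-aip={query}", "--format=csv"],
            capture_output=True, text=True,
        )
        if proc.returncode != 0:
            return None
        # First CSV row is the header, matching the query fields.
        return list(csv.DictReader(proc.stdout.splitlines()))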
#[derive(Serialize, Debug, Clone)] #[derive(Serialize, Debug, Clone)]
pub struct SystemInfo { pub struct SystemInfo {
cpu_count: usize, cpu_count: usize,
@ -335,10 +390,14 @@ impl Env {
system_env: SystemInfo::new(), system_env: SystemInfo::new(),
nvidia_info: NvidiaSmiInfo::new(), nvidia_info: NvidiaSmiInfo::new(),
xpu_info: XpuSmiInfo::new(), xpu_info: XpuSmiInfo::new(),
hpu_info: HpuSmiInfo::new(),
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"), git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"), docker_label: option_env!("DOCKER_LABEL").unwrap_or("N/A"),
} }
} }
pub fn is_hpu_device(&self) -> bool {
self.hpu_info.is_some()
}
} }
pub fn is_container() -> io::Result<bool> { pub fn is_container() -> io::Result<bool> {