mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 16:32:12 +00:00
clean cuda/rocm code in hpu backend, enable flat_hpu
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
ae4451c3da
commit
201dc6294f
@ -96,7 +96,7 @@ RUN cd server && \
|
||||
pip install "git+https://github.com/HabanaAI/DeepSpeed.git@${HABANA_VERSION}" && \
|
||||
BUILD_CUDA_EXT=0 pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@097dd04e --no-build-isolation && \
|
||||
pip install . --no-cache-dir
|
||||
|
||||
RUN pip install git+https://github.com/sywangyi/vllm-hpu-extension.git
|
||||
# Install benchmarker
|
||||
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||
# Install router
|
||||
|
@ -1,43 +1,28 @@
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
import os
|
||||
from .common import (
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
trim_attn_metadata,
|
||||
trim_seqlen_metadata,
|
||||
)
|
||||
|
||||
from .common import Seqlen
|
||||
|
||||
if os.getenv("USE_FLASH_ATTENTION", "true").lower() == "false":
|
||||
raise ImportError("`USE_FLASH_ATTENTION` is false.")
|
||||
if SYSTEM == "cuda":
|
||||
from .cuda import (
|
||||
from .hpu import (
|
||||
SUPPORTS_WINDOWING,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
SUPPORTS_WINDOWING,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
elif SYSTEM == "rocm":
|
||||
from .rocm import (
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
)
|
||||
elif SYSTEM == "ipex":
|
||||
from .ipex import (
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
SUPPORTS_WINDOWING,
|
||||
)
|
||||
else:
|
||||
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
||||
)
|
||||
|
||||
|
||||
# KVCache needs `reshape_and_cache`, so ensure that it is defined already.
|
||||
from .kv_cache import KVCache, get_kv_scales
|
||||
|
||||
__all__ = [
|
||||
"attention",
|
||||
"get_kv_scales",
|
||||
"paged_attention",
|
||||
"reshape_and_cache",
|
||||
"PREFILL_IN_KV_CACHE",
|
||||
"SUPPORTS_WINDOWING",
|
||||
"KVCache",
|
||||
"Seqlen",
|
||||
"HPUPagedAttentionMetadata",
|
||||
"trim_seqlen_metadata",
|
||||
"trim_attn_metadata",
|
||||
]
|
||||
|
@ -1,31 +1,94 @@
|
||||
from dataclasses import dataclass
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
import torch
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Dict
|
||||
import collections
|
||||
|
||||
_TYPE_CACHE = {}
|
||||
|
||||
|
||||
if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||
@dataclass
|
||||
class HPUPagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
block_list: Optional[torch.Tensor]
|
||||
block_mapping: Optional[torch.Tensor]
|
||||
block_usage: Optional[torch.Tensor]
|
||||
block_scales: Optional[torch.Tensor]
|
||||
block_groups: Optional[torch.Tensor]
|
||||
attn_bias: Optional[torch.Tensor]
|
||||
|
||||
|
||||
def subtuple(
|
||||
obj: object,
|
||||
typename: str,
|
||||
to_copy: List[str],
|
||||
to_override: Optional[Dict[str, object]] = None,
|
||||
):
|
||||
if obj is None:
|
||||
return None
|
||||
if to_override is None:
|
||||
to_override = {}
|
||||
fields = set(to_copy) | set(to_override.keys())
|
||||
if isinstance(obj, dict):
|
||||
values = {key: obj[key] for key in fields if key in obj}
|
||||
else:
|
||||
values = {f: to_override.get(f, getattr(obj, f)) for f in fields}
|
||||
if typename not in _TYPE_CACHE:
|
||||
_TYPE_CACHE[typename] = collections.namedtuple(typename, " ".join(fields))
|
||||
return _TYPE_CACHE[typename](**values)
|
||||
|
||||
|
||||
def trim_attn_metadata(metadata: HPUPagedAttentionMetadata) -> object:
|
||||
# NOTE(kzawora): To anyone working on this in the future:
|
||||
# Trimming metadata is required when using HPUGraphs.
|
||||
# Attention metadata is going to be hashed by PT bridge, and
|
||||
# appropriate HPUGraphs will be matched based on all inputs' hash.
|
||||
|
||||
# Before you put more keys in here, make sure you know their
|
||||
# value type and make sure you know how it's going to be hashed.
|
||||
# You can find that information in input_hash function
|
||||
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
|
||||
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
|
||||
|
||||
# If you use primitive types here - they will get hashed based
|
||||
# on their value. You *will* get lots of excessive graph captures
|
||||
# (and an OOM eventually) if you decide to put something like
|
||||
# seq_len int here.
|
||||
# If you absolutely need a scalar, put it in a tensor. Tensors
|
||||
# get hashed using their metadata, not their values:
|
||||
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
|
||||
# input_hash(123) != input_hash(321)
|
||||
# input_hash("abc") != input_hash("cba")
|
||||
attention_metadata = subtuple(
|
||||
metadata,
|
||||
"TrimmedAttentionMetadata",
|
||||
[
|
||||
"block_list",
|
||||
"block_mapping",
|
||||
"block_usage",
|
||||
"block_scales",
|
||||
"block_groups",
|
||||
"attn_bias",
|
||||
],
|
||||
)
|
||||
return attention_metadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
input_lengths: torch.Tensor
|
||||
prefix_lengths: torch.Tensor
|
||||
cache_lengths: torch.Tensor
|
||||
cu_seqlen_q: Optional[torch.Tensor]
|
||||
cu_seqlen_k: Optional[torch.Tensor]
|
||||
max_q: int
|
||||
max_k: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_lengths,
|
||||
prefix_lengths,
|
||||
cache_lengths,
|
||||
cu_seqlen_q=None,
|
||||
max_q=None,
|
||||
max_k=None,
|
||||
):
|
||||
self.input_lengths = input_lengths
|
||||
self.prefix_lengths = prefix_lengths
|
||||
self.cache_lengths = cache_lengths
|
||||
device = self.input_lengths.device
|
||||
shape = self.input_lengths.shape
|
||||
if cu_seqlen_q is None:
|
||||
@ -34,39 +97,51 @@ if ATTENTION in {"flashinfer", "flashdecoding"}:
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
max_q = 1
|
||||
else:
|
||||
assert max_q is not None
|
||||
assert max_k is not None
|
||||
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
|
||||
|
||||
# cuda graphs don't like this and this is necessary to clamp within mistral
|
||||
# Although FA2 might not want the clamping
|
||||
# cu_seqlen_k[0] = 0
|
||||
total = self.input_lengths + self.prefix_lengths
|
||||
total = self.input_lengths + self.cache_lengths
|
||||
torch.cumsum(total, -1, out=cu_seqlen_k[1:])
|
||||
|
||||
self.cu_seqlen_q = cu_seqlen_q
|
||||
self.cu_seqlen_k = cu_seqlen_k
|
||||
self.max_q = max_q
|
||||
self.max_k = max_k
|
||||
|
||||
def clamp(self, max):
|
||||
# Flash decoding doesn't need to clamp
|
||||
return self
|
||||
|
||||
else:
|
||||
|
||||
@dataclass
|
||||
class Seqlen:
|
||||
input_lengths: torch.Tensor
|
||||
prefix_lengths: torch.Tensor
|
||||
cu_seqlen_q: torch.Tensor
|
||||
max_q: int
|
||||
max_k: int
|
||||
def trim_seqlen_metadata(metadata: Seqlen) -> object:
|
||||
# NOTE(kzawora): To anyone working on this in the future:
|
||||
# Trimming metadata is required when using HPUGraphs.
|
||||
# Attention metadata is going to be hashed by PT bridge, and
|
||||
# appropriate HPUGraphs will be matched based on all inputs' hash.
|
||||
|
||||
def clamp(self, max):
|
||||
if SYSTEM == "rocm":
|
||||
return self
|
||||
raise NotImplementedError("Not implemented seqlen for paged")
|
||||
return Seqlen(torch.clamp(self.input_lengths, max=max))
|
||||
# Before you put more keys in here, make sure you know their
|
||||
# value type and make sure you know how it's going to be hashed.
|
||||
# You can find that information in input_hash function
|
||||
# in habana_frameworks/torch/hpu/graphs.py. You can also hash
|
||||
# it manually with torch.hpu.graphs.input_hash(attention_metadata)
|
||||
|
||||
# If you use primitive types here - they will get hashed based
|
||||
# on their value. You *will* get lots of excessive graph captures
|
||||
# (and an OOM eventually) if you decide to put something like
|
||||
# seq_len int here.
|
||||
# If you absolutely need a scalar, put it in a tensor. Tensors
|
||||
# get hashed using their metadata, not their values:
|
||||
# input_hash(torch.tensor(123)) == input_hash(torch.tensor(321))
|
||||
# input_hash(123) != input_hash(321)
|
||||
# input_hash("abc") != input_hash("cba")
|
||||
attention_metadata = subtuple(
|
||||
metadata,
|
||||
"TrimmedSeqlen",
|
||||
[
|
||||
"input_lengths",
|
||||
"cache_lengths",
|
||||
"cu_seqlen_q",
|
||||
"cu_seqlen_k",
|
||||
],
|
||||
)
|
||||
return attention_metadata
|
||||
|
@ -1,357 +0,0 @@
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import (
|
||||
ATTENTION,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
try:
|
||||
from vllm._C import cache_ops
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||
)
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
if ATTENTION in {"flashdecoding", "flashinfer"}:
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
cache_ops.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots, "auto", 1.0
|
||||
)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
# Copyright 2023 The vLLM team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||
# block_size = value_cache.shape[3]
|
||||
block_size = BLOCK_SIZE
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flashinfer import decode_state
|
||||
|
||||
return decode_state.get().forward(
|
||||
query.contiguous(),
|
||||
paged_kv_cache=(key_cache, value_cache),
|
||||
logits_soft_cap=softcap,
|
||||
sm_scale=softmax_scale,
|
||||
)
|
||||
elif ATTENTION == "flashdecoding":
|
||||
max_q = 1
|
||||
max_k = max_s
|
||||
import flash_attn_2_cuda
|
||||
|
||||
# TODO fixme when flash contains the fix.
|
||||
# Number of splits is not correctly handled
|
||||
# by the current path
|
||||
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
|
||||
# This fails becuase we're using causal, therefore window_right is set to 0 and the split logic is never applied.
|
||||
if softcap is None:
|
||||
softcap = 0.0
|
||||
out = flash_attn_2_cuda.varlen_fwd(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
None,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
None, # pad_k
|
||||
None,
|
||||
block_tables,
|
||||
None,
|
||||
max_q,
|
||||
max_k,
|
||||
0.0, # dropout
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
True, # causal
|
||||
-1, # Window_left
|
||||
-1, # Window right
|
||||
softcap,
|
||||
False, # return softmax
|
||||
None, # generator
|
||||
)
|
||||
return out[0]
|
||||
else:
|
||||
if softcap is not None:
|
||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||
input_lengths = seqlen.input_lengths
|
||||
from vllm._C import ops
|
||||
|
||||
out = torch.empty_like(query)
|
||||
|
||||
use_v1 = max_s <= 8192 and (
|
||||
max_num_partitions == 1 or num_seqs * num_heads > 512
|
||||
)
|
||||
if use_v1:
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=out.dtype,
|
||||
device=out.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=out.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
try:
|
||||
is_ampere_or_newer = major >= 8 and minor >= 0
|
||||
if not is_ampere_or_newer:
|
||||
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
|
||||
|
||||
import flash_attn_2_cuda
|
||||
|
||||
V2 = True
|
||||
except ImportError:
|
||||
try:
|
||||
import flash_attn_cuda
|
||||
|
||||
V2 = False
|
||||
except ImportError as e:
|
||||
if major >= 8:
|
||||
architecture_suffix = f"-{SYSTEM}"
|
||||
raise ImportError(
|
||||
"Flash Attention V2 is not installed.\n"
|
||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
||||
)
|
||||
elif is_sm75:
|
||||
raise ImportError(
|
||||
"Flash Attention is not installed.\n"
|
||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||
) from e
|
||||
else:
|
||||
raise ImportError(
|
||||
f"GPU with CUDA capability {major} {minor} is not supported"
|
||||
) from e
|
||||
|
||||
|
||||
SUPPORTS_WINDOWING = V2
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softcap=0.0,
|
||||
):
|
||||
from text_generation_server.layers.attention.flashinfer import (
|
||||
prefill_with_paged_kv_state,
|
||||
)
|
||||
|
||||
return prefill_with_paged_kv_state.get().forward(
|
||||
q.contiguous(),
|
||||
causal=causal,
|
||||
paged_kv_cache=(key_cache, value_cache),
|
||||
logits_soft_cap=softcap,
|
||||
sm_scale=softmax_scale,
|
||||
window_left=window_size_left,
|
||||
)
|
||||
|
||||
elif V2:
|
||||
|
||||
def attention(
|
||||
q,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softcap=0.0,
|
||||
):
|
||||
out = torch.empty_like(q)
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_k,
|
||||
None,
|
||||
None,
|
||||
block_tables,
|
||||
None,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
window_size_left,
|
||||
0,
|
||||
softcap,
|
||||
False,
|
||||
None,
|
||||
)[0]
|
||||
|
||||
else:
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap=None,
|
||||
):
|
||||
if window_size_left != -1:
|
||||
raise NotImplementedError(
|
||||
"window_size_left is only available with flash attn v2"
|
||||
)
|
||||
if softcap is not None:
|
||||
raise NotImplementedError("softcap is only available with flash attn v2")
|
||||
|
||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||
if k.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
if k.shape[1] == 1:
|
||||
k = k.expand(-1, q.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = k.shape
|
||||
k = (
|
||||
k.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // k.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
if v.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
if v.shape[1] == 1:
|
||||
v = v.expand(-1, q.shape[1], -1)
|
||||
# Grouped attention reshape
|
||||
else:
|
||||
original_shape = v.shape
|
||||
v = (
|
||||
v.unsqueeze(2)
|
||||
.expand(-1, -1, q.shape[1] // v.shape[1], -1)
|
||||
.reshape(original_shape[0], -1, original_shape[2])
|
||||
)
|
||||
|
||||
out = torch.empty_like(q)
|
||||
flash_attn_cuda.fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
0,
|
||||
None,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
# Prefill in the cache with every kind of attention, unless we
|
||||
# have a configuration that requires flash-attention v1, which
|
||||
# does not support block tables.
|
||||
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
|
@ -1,813 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Fused Attention
|
||||
===============
|
||||
|
||||
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||
|
||||
Features supported:
|
||||
|
||||
1) Fwd with causal masking
|
||||
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||
3) Support for different sequence lengths for q and k
|
||||
4) Nested tensor API currently does not support dropout or bias.
|
||||
|
||||
Not currently supported:
|
||||
|
||||
1) Non power of two head dims
|
||||
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
torch_dtype: tl.constexpr = torch.float16
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv_fn(x, y):
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
@triton.jit
|
||||
def max_fn(x, y):
|
||||
return tl.math.max(x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
ms = tl.arange(0, m)
|
||||
ns = tl.arange(0, n)
|
||||
return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_offsets = dropout_offsets(
|
||||
philox_seed, philox_offset, dropout_p, m, n, stride
|
||||
).to(tl.uint32)
|
||||
# TODO: use tl.randint for better performance
|
||||
return tl.rand(philox_seed, rng_offsets)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
|
||||
rng_keep = rng_output > dropout_p
|
||||
return rng_keep
|
||||
|
||||
|
||||
@triton.jit
|
||||
def load_fn(block_ptr, first, second, pad):
|
||||
if first and second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
|
||||
elif first:
|
||||
tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad)
|
||||
elif second:
|
||||
tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad)
|
||||
else:
|
||||
tensor = tl.load(block_ptr)
|
||||
return tensor
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
actual_seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
OFFS_M: tl.constexpr,
|
||||
OFFS_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
MASK_STEPS: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
PADDED_HEAD: tl.constexpr,
|
||||
):
|
||||
# loop over k, v, and update accumulator
|
||||
for start_n in range(block_min, block_max, BLOCK_N):
|
||||
# For padded blocks, we will overrun the tensor size if
|
||||
# we load all BLOCK_N. For others, the blocks are all within range.
|
||||
k = load_fn(
|
||||
K_block_ptr,
|
||||
PADDED_HEAD,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
"zero",
|
||||
)
|
||||
if PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
# We start from end of seqlen_k so only the first iteration would need
|
||||
# to be checked for padding if it is not a multiple of block_n
|
||||
# TODO: This can be optimized to only be true for the padded block.
|
||||
if MASK_STEPS: # noqa: SIM102
|
||||
# If this is the last block / iteration, we want to
|
||||
# mask if the sequence length is not a multiple of block size
|
||||
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
|
||||
# if not is_modulo_mn. last step might get wasted but that is okay.
|
||||
# check if this masking works for that case.
|
||||
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||
boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32)
|
||||
size_n = start_n + OFFS_N[None, :]
|
||||
mask = size_n < boundary_m[:, None]
|
||||
qk = tl.where(mask, qk, float("-inf"))
|
||||
if IS_CAUSAL:
|
||||
causal_boundary = start_n + offs_n_causal
|
||||
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
||||
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||
# -- compute qk ----
|
||||
qk += tl.dot(q, k)
|
||||
if bias_ptr is not None:
|
||||
bias = load_fn(
|
||||
bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero"
|
||||
)
|
||||
# While bias is added after multiplying qk with sm_scale, our
|
||||
# optimization to use 2^x instead of e^x results in an additional
|
||||
# scale factor of log2(e) which we must also multiply the bias with.
|
||||
qk += bias * 1.44269504089
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
qk = qk - m_ij[:, None]
|
||||
p = tl.math.exp2(qk)
|
||||
|
||||
# CAVEAT: Must update l_ij before applying dropout
|
||||
l_ij = tl.sum(p, 1)
|
||||
if ENABLE_DROPOUT:
|
||||
philox_offset = (
|
||||
batch_philox_offset
|
||||
+ start_m * BLOCK_M * actual_seqlen_k
|
||||
+ start_n
|
||||
- BLOCK_N
|
||||
)
|
||||
keep = dropout_mask(
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
dropout_p,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
actual_seqlen_k,
|
||||
)
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
p = tl.where(keep, p, 0.0)
|
||||
elif RETURN_ENCODED_SOFTMAX:
|
||||
tl.store(
|
||||
encoded_softmax_block_ptr,
|
||||
p.to(encoded_softmax_block_ptr.type.element_ty),
|
||||
)
|
||||
# -- update output accumulator --
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
acc = acc * alpha[:, None]
|
||||
if not PRE_LOAD_V:
|
||||
v = load_fn(
|
||||
V_block_ptr,
|
||||
MASK_STEPS and (n_extra_tokens != 0),
|
||||
PADDED_HEAD,
|
||||
"zero",
|
||||
)
|
||||
# -- update m_i and l_i
|
||||
l_i = l_i * alpha + l_ij
|
||||
# update m_i and l_i
|
||||
m_i = m_ij
|
||||
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(
|
||||
encoded_softmax_block_ptr, (0, BLOCK_N)
|
||||
)
|
||||
return acc, l_i, m_i
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 256,
|
||||
"BLOCK_N": 64,
|
||||
"waves_per_eu": 2,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 128,
|
||||
"BLOCK_N": 128,
|
||||
"waves_per_eu": 2,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 256,
|
||||
"BLOCK_N": 128,
|
||||
"waves_per_eu": 2,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 128,
|
||||
"BLOCK_N": 64,
|
||||
"waves_per_eu": 3,
|
||||
"PRE_LOAD_V": True,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 128,
|
||||
"BLOCK_N": 64,
|
||||
"waves_per_eu": 3,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 64,
|
||||
"BLOCK_N": 64,
|
||||
"waves_per_eu": 4,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 32,
|
||||
"BLOCK_N": 32,
|
||||
"waves_per_eu": 4,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=8,
|
||||
),
|
||||
# TODO: This config fails with head_size not pow2 with data mismatches.
|
||||
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
|
||||
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 16,
|
||||
"BLOCK_N": 16,
|
||||
"waves_per_eu": 1,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_M": 128,
|
||||
"BLOCK_N": 64,
|
||||
"waves_per_eu": 1,
|
||||
"PRE_LOAD_V": False,
|
||||
},
|
||||
num_stages=1,
|
||||
num_warps=4,
|
||||
),
|
||||
],
|
||||
key=["IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
|
||||
)
|
||||
@triton.jit
|
||||
def attn_fwd(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
bias,
|
||||
sm_scale,
|
||||
L,
|
||||
Out,
|
||||
stride_qz,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vh,
|
||||
stride_vk,
|
||||
stride_vn,
|
||||
stride_oz,
|
||||
stride_oh,
|
||||
stride_om,
|
||||
stride_on,
|
||||
stride_bz,
|
||||
stride_bh,
|
||||
stride_bm,
|
||||
stride_bn,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
philox_offset_base,
|
||||
encoded_softmax,
|
||||
HQ: tl.constexpr,
|
||||
HK: tl.constexpr,
|
||||
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||
MAX_SEQLENS_Q: tl.constexpr,
|
||||
MAX_SEQLENS_K: tl.constexpr,
|
||||
VARLEN: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
PRE_LOAD_V: tl.constexpr,
|
||||
BIAS_TYPE: tl.constexpr,
|
||||
ENABLE_DROPOUT: tl.constexpr,
|
||||
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_h_q = tl.program_id(1)
|
||||
off_z = tl.program_id(2)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
if VARLEN:
|
||||
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
|
||||
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
|
||||
# small for all start_m so for those we return early.
|
||||
if start_m * BLOCK_M > seqlen_q:
|
||||
return
|
||||
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
|
||||
else:
|
||||
cu_seqlens_q_start = 0
|
||||
cu_seqlens_k_start = 0
|
||||
seqlen_q = MAX_SEQLENS_Q
|
||||
seqlen_k = MAX_SEQLENS_K
|
||||
|
||||
# Now we compute whether we need to exit early due to causal masking.
|
||||
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
|
||||
# are completely masked, resulting in 0s written to the output, and
|
||||
# inf written to LSE. We don't need to do any GEMMs in this case.
|
||||
# This block of code determines what N is, and if this WG is operating
|
||||
# on those M rows.
|
||||
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
|
||||
if IS_CAUSAL:
|
||||
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
|
||||
# If seqlen_q != seqlen_k, attn scores are rectangular which means
|
||||
# the causal mask boundary is bottom right aligned, and ends at either
|
||||
# the top edge (seqlen_q < seqlen_k) or left edge.
|
||||
# This captures the decrease in n_blocks if we have a rectangular attn
|
||||
# matrix
|
||||
n_blocks_seqlen = cdiv_fn(
|
||||
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N
|
||||
)
|
||||
# This is what adjusts the block_max for the current WG, only
|
||||
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||
# If we have no blocks after adjusting for seqlen deltas, this WG is
|
||||
# part of the blocks that are all 0. We exit early.
|
||||
if n_blocks <= 0:
|
||||
o_offset = (
|
||||
off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
|
||||
)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
|
||||
# We still need to write 0s to the result
|
||||
# tl.store(O_block_ptr,
|
||||
# acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||
# + offs_m
|
||||
# We store inf to LSE, not -inf because in the bwd pass,
|
||||
# we subtract this
|
||||
# from qk which makes it -inf, such that exp(qk - inf) = 0
|
||||
# for these masked blocks.
|
||||
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
|
||||
# tl.store(l_ptrs, l)
|
||||
# TODO: Should dropout and return encoded softmax be handled here?
|
||||
return
|
||||
|
||||
# If MQA / GQA, set the K and V head offsets appropriately.
|
||||
GROUP_SIZE: tl.constexpr = HQ // HK
|
||||
if GROUP_SIZE != 1:
|
||||
off_h_k = off_h_q // GROUP_SIZE
|
||||
else:
|
||||
off_h_k = off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
elif seqlen_k % BLOCK_N:
|
||||
n_extra_tokens = seqlen_k % BLOCK_N
|
||||
PADDED_HEAD: tl.constexpr = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||
|
||||
# Compute pointers for all the tensors used in this kernel.
|
||||
q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + q_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
if BIAS_TYPE != 0:
|
||||
bias_ptr = tl.make_block_ptr(
|
||||
base=bias + off_h_q * stride_bh,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(stride_bm, stride_bn),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
bias_ptr = None
|
||||
if ENABLE_DROPOUT:
|
||||
batch_philox_offset = (
|
||||
philox_offset_base + (off_z * HQ + off_h_q) * seqlen_q * seqlen_k
|
||||
)
|
||||
else:
|
||||
batch_philox_offset = 0
|
||||
# We can ask to return the dropout mask without actually doing any dropout.
|
||||
# In this case, we return an invalid pointer so indicate the mask is not i
|
||||
# valid.
|
||||
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.make_block_ptr(
|
||||
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
|
||||
shape=(seqlen_q, seqlen_k),
|
||||
strides=(seqlen_k, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
encoded_softmax_block_ptr = 0
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
|
||||
# have native e^x support in HW.
|
||||
qk_scale = sm_scale * 1.44269504089
|
||||
# Q is loaded once at the beginning and shared by all N blocks.
|
||||
q = load_fn(Q_block_ptr, True, PADDED_HEAD, "zero")
|
||||
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
|
||||
|
||||
# Here we compute how many full and masked blocks we have.
|
||||
padded_block_k = n_extra_tokens != 0
|
||||
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
|
||||
if IS_CAUSAL:
|
||||
# There are always at least BLOCK_M // BLOCK_N masked blocks.
|
||||
# Additionally there might be one more due to dissimilar seqlens.
|
||||
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
|
||||
else:
|
||||
# Padding on Q does not need to be masked in the FA loop.
|
||||
masked_blocks = padded_block_k
|
||||
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
|
||||
# block. In this case we might exceed n_blocks so pick the min.
|
||||
masked_blocks = min(masked_blocks, n_blocks)
|
||||
n_full_blocks = n_blocks - masked_blocks
|
||||
block_min = 0
|
||||
block_max = n_blocks * BLOCK_N
|
||||
# Compute for full blocks. Here we set causal to false regardless of its
|
||||
# value because there is no masking. Similarly we do not need padding.
|
||||
if n_full_blocks > 0:
|
||||
block_max = (n_blocks - masked_blocks) * BLOCK_N
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
|
||||
block_min,
|
||||
block_max,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
bias_ptr,
|
||||
# IS_CAUSAL, ....
|
||||
False,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
False,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
PADDED_HEAD,
|
||||
)
|
||||
block_min = block_max
|
||||
block_max = n_blocks * BLOCK_N
|
||||
|
||||
tl.debug_barrier()
|
||||
# Remaining blocks, if any, are full / not masked.
|
||||
if masked_blocks > 0:
|
||||
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
|
||||
if bias_ptr is not None:
|
||||
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
|
||||
if RETURN_ENCODED_SOFTMAX:
|
||||
encoded_softmax_block_ptr = tl.advance(
|
||||
encoded_softmax_block_ptr, (0, n_full_blocks)
|
||||
)
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc,
|
||||
l_i,
|
||||
m_i,
|
||||
q,
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
start_m,
|
||||
seqlen_k,
|
||||
dropout_p,
|
||||
philox_seed,
|
||||
batch_philox_offset,
|
||||
encoded_softmax_block_ptr,
|
||||
block_min,
|
||||
block_max,
|
||||
offs_n_causal,
|
||||
masked_blocks,
|
||||
n_extra_tokens,
|
||||
bias_ptr,
|
||||
IS_CAUSAL,
|
||||
BLOCK_M,
|
||||
BLOCK_DMODEL,
|
||||
BLOCK_N,
|
||||
offs_m,
|
||||
offs_n,
|
||||
# _, MASK_STEPS, ...
|
||||
PRE_LOAD_V,
|
||||
True,
|
||||
ENABLE_DROPOUT,
|
||||
RETURN_ENCODED_SOFTMAX,
|
||||
PADDED_HEAD,
|
||||
)
|
||||
# epilogue
|
||||
acc = acc / l_i[:, None]
|
||||
if ENABLE_DROPOUT:
|
||||
acc = acc / (1 - dropout_p)
|
||||
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
|
||||
# then we have one block with a row of all NaNs which come from computing
|
||||
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
|
||||
# and store 0s where there are NaNs as these rows should've been zeroed out.
|
||||
end_m_idx = (start_m + 1) * BLOCK_M
|
||||
start_m_idx = start_m * BLOCK_M
|
||||
causal_start_idx = seqlen_q - seqlen_k
|
||||
acc = acc.to(Out.type.element_ty)
|
||||
if IS_CAUSAL: # noqa: SIM102
|
||||
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||
out_mask_boundary = tl.full(
|
||||
(BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32
|
||||
)
|
||||
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||
out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :]
|
||||
z = 0.0
|
||||
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||
# write back LSE
|
||||
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
|
||||
# few rows. This is only true for the last M block. For others,
|
||||
# overflow_size will be -ve
|
||||
# overflow_size = end_m_idx - seqlen_q
|
||||
# if overflow_size > 0:
|
||||
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
|
||||
# # This is a > check because mask being 0 blocks the store.
|
||||
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
|
||||
# else:
|
||||
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
|
||||
# write back O
|
||||
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + o_offset,
|
||||
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
# Need boundary check on this to make sure the padding from the
|
||||
# Q and KV tensors in both dims are not part of what we store back.
|
||||
# TODO: Do the boundary check optionally.
|
||||
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||
|
||||
|
||||
def check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
max_seqlens=None,
|
||||
cu_seqlens_q=None,
|
||||
cu_seqlens_k=None,
|
||||
):
|
||||
assert q.dim() == k.dim() and q.dim() == v.dim()
|
||||
if varlen:
|
||||
assert q.dim() == 3
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
assert cu_seqlens_q is not None
|
||||
assert cu_seqlens_k is not None
|
||||
assert len(cu_seqlens_q) == len(cu_seqlens_k)
|
||||
else:
|
||||
assert q.dim() == 4
|
||||
batch, nheads_q, seqlen_q, head_size = q.shape
|
||||
_, nheads_k, seqlen_k, _ = k.shape
|
||||
assert max_seqlens > 0
|
||||
assert k.shape == v.shape
|
||||
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||
# TODO: Change assert if we support qkl f8 and v f16
|
||||
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||
# TODO: Fix assert to check head size <=256 once supported
|
||||
assert head_size <= 128
|
||||
assert o.shape == q.shape
|
||||
assert (nheads_q % nheads_k) == 0
|
||||
|
||||
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlens_q,
|
||||
max_seqlens_k,
|
||||
causal=False,
|
||||
sm_scale=1.0,
|
||||
bias=None,
|
||||
):
|
||||
if o is None:
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
|
||||
check_args(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
varlen=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
)
|
||||
if True: # varlen
|
||||
total_q, nheads_q, head_size = q.shape
|
||||
total_k, nheads_k, _ = k.shape
|
||||
batch = len(cu_seqlens_q) - 1
|
||||
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
|
||||
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
|
||||
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
|
||||
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
|
||||
else:
|
||||
batch, seqlen_q, nheads_q, head_size = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
|
||||
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
|
||||
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
|
||||
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||
|
||||
# Get closest power of 2 over or equal to 32.
|
||||
padded_d_model = 1 << (head_size - 1).bit_length()
|
||||
padded_d_model = max(padded_d_model, 16)
|
||||
|
||||
def grid(META):
|
||||
return triton.cdiv(max_seqlens_q, META["BLOCK_M"]), nheads_q, batch
|
||||
|
||||
encoded_softmax = None
|
||||
|
||||
# Seed the RNG so we get reproducible results for testing.
|
||||
philox_seed = 0x1BF52
|
||||
philox_offset = 0x1D4B42
|
||||
|
||||
if bias is not None:
|
||||
bias_strides = (
|
||||
bias.stride(0),
|
||||
bias.stride(1),
|
||||
bias.stride(2),
|
||||
bias.stride(3),
|
||||
)
|
||||
else:
|
||||
bias_strides = (0, 0, 0, 0)
|
||||
|
||||
attn_fwd[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
sm_scale,
|
||||
None,
|
||||
o,
|
||||
*q_strides,
|
||||
*k_strides,
|
||||
*v_strides,
|
||||
*o_strides,
|
||||
*bias_strides,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
dropout_p=0.0,
|
||||
philox_seed=philox_seed,
|
||||
philox_offset_base=philox_offset,
|
||||
encoded_softmax=encoded_softmax,
|
||||
HQ=nheads_q,
|
||||
HK=nheads_k,
|
||||
ACTUAL_BLOCK_DMODEL=head_size,
|
||||
MAX_SEQLENS_Q=max_seqlens_q,
|
||||
MAX_SEQLENS_K=max_seqlens_k,
|
||||
IS_CAUSAL=causal,
|
||||
VARLEN=True,
|
||||
BLOCK_DMODEL=padded_d_model,
|
||||
BIAS_TYPE=0 if bias is None else 1,
|
||||
ENABLE_DROPOUT=False,
|
||||
RETURN_ENCODED_SOFTMAX=False,
|
||||
)
|
||||
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = head_size
|
||||
ctx.causal = causal
|
||||
ctx.dropout_p = 0.0
|
||||
ctx.philox_seed = philox_seed
|
||||
ctx.philox_offset = philox_offset
|
||||
ctx.encoded_softmax = encoded_softmax
|
||||
ctx.return_encoded_softmax = False
|
||||
return o, encoded_softmax
|
||||
|
||||
|
||||
triton_attention = _attention.apply
|
@ -1,251 +0,0 @@
|
||||
from typing import Optional
|
||||
from contextvars import ContextVar
|
||||
from contextlib import contextmanager
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
prefill_state: ContextVar[flashinfer.BatchPrefillWithRaggedKVCacheWrapper] = ContextVar(
|
||||
"prefill_state"
|
||||
)
|
||||
|
||||
prefill_with_paged_kv_state: ContextVar[
|
||||
flashinfer.BatchPrefillWithPagedKVCacheWrapper
|
||||
] = ContextVar("prefill_with_paged_kv_state")
|
||||
|
||||
decode_state: ContextVar[flashinfer.BatchDecodeWithPagedKVCacheWrapper] = ContextVar(
|
||||
"decode_state"
|
||||
)
|
||||
|
||||
workspace: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def get_workspace(device):
|
||||
"""Get shared flashinfer workspace."""
|
||||
global workspace
|
||||
if workspace is None:
|
||||
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
|
||||
return workspace
|
||||
|
||||
|
||||
def create_prefill_with_paged_kv_state(
|
||||
*,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Create a prefill state that uses the KV cache."""
|
||||
workspace_buffer = get_workspace(device)
|
||||
return flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_prefill_with_paged_kv_state(
|
||||
*,
|
||||
state: flashinfer.BatchPrefillWithPagedKVCacheWrapper,
|
||||
block_tables: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer prefill state to the given
|
||||
`state` and parameters. This state will be used by all calls to the
|
||||
`attention` function while the context manager is active.
|
||||
"""
|
||||
|
||||
indptr = torch.zeros(
|
||||
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
|
||||
)
|
||||
# Round up to page size and then calculate the cumulative sum to get
|
||||
# the indices into the block table.
|
||||
torch.add(input_lengths, page_size - 1, out=indptr[1:])
|
||||
indptr[1:].div_(page_size, rounding_mode="floor")
|
||||
indptr[1:].cumsum_(-1)
|
||||
|
||||
# Get the lengths of the last page in a block.
|
||||
if page_size == 1:
|
||||
last_page_len = torch.ones(
|
||||
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
|
||||
)
|
||||
else:
|
||||
last_page_len = torch.empty(
|
||||
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
|
||||
)
|
||||
torch.sub(input_lengths, 1, out=last_page_len)
|
||||
last_page_len.remainder_(page_size)
|
||||
last_page_len += 1
|
||||
|
||||
token = prefill_with_paged_kv_state.set(state)
|
||||
try:
|
||||
state.begin_forward(
|
||||
qo_indptr=cu_seqlens,
|
||||
paged_kv_indptr=indptr,
|
||||
paged_kv_indices=block_tables,
|
||||
paged_kv_last_page_len=last_page_len,
|
||||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
q_data_type=dtype,
|
||||
page_size=page_size,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
state.end_forward()
|
||||
if token is not None:
|
||||
prefill_with_paged_kv_state.reset(token)
|
||||
|
||||
|
||||
def create_prefill_state(
|
||||
*,
|
||||
device: torch.device,
|
||||
):
|
||||
"""Create a prefill state."""
|
||||
workspace_buffer = get_workspace(device)
|
||||
return flashinfer.BatchPrefillWithRaggedKVCacheWrapper(
|
||||
workspace_buffer, kv_layout="NHD", use_cuda_graph=False
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_prefill_state(
|
||||
*,
|
||||
state: flashinfer.BatchPrefillWithRaggedKVCacheWrapper,
|
||||
cu_seqlens: torch.Tensor,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer prefill state to the given
|
||||
`state` and parameters. This state will be used by all calls to the
|
||||
`attention` function while the context manager is active.
|
||||
"""
|
||||
|
||||
token = prefill_state.set(state)
|
||||
try:
|
||||
state.begin_forward(
|
||||
qo_indptr=cu_seqlens,
|
||||
kv_indptr=cu_seqlens,
|
||||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
q_data_type=dtype,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
state.end_forward()
|
||||
if token is not None:
|
||||
prefill_state.reset(token)
|
||||
|
||||
|
||||
def create_decode_state(
|
||||
*,
|
||||
device: torch.device,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
"""Create a decode state."""
|
||||
workspace_buffer = get_workspace(device)
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout="NHD",
|
||||
use_cuda_graph=False,
|
||||
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||
)
|
||||
|
||||
|
||||
def create_decode_state_cuda_graphs(
|
||||
*,
|
||||
device: torch.device,
|
||||
block_tables: torch.Tensor,
|
||||
block_tables_ptr: torch.Tensor,
|
||||
last_page_len: torch.Tensor,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
):
|
||||
"""
|
||||
Create a decode state for use with CUDA Graphs. `block_tables`,
|
||||
`block_tables_ptr`, and `last_page_len` are used in CUDA Graphs and are
|
||||
therefore stored as part of the state.
|
||||
"""
|
||||
workspace_buffer = get_workspace(device)
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout="NHD",
|
||||
use_cuda_graph=True,
|
||||
paged_kv_indices_buffer=block_tables,
|
||||
paged_kv_indptr_buffer=block_tables_ptr,
|
||||
paged_kv_last_page_len_buffer=last_page_len,
|
||||
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
|
||||
use_tensor_cores=num_groups not in [1, 2, 4, 8],
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def use_decode_state(
|
||||
*,
|
||||
state: flashinfer.BatchDecodeWithPagedKVCacheWrapper,
|
||||
input_lengths: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
window_left: int,
|
||||
):
|
||||
"""
|
||||
Context manager to set the active flashinfer decoding state to the given
|
||||
`state` and parameters. This state will be used by all calls to the
|
||||
`paged_attention` function while the context manager is active.
|
||||
"""
|
||||
indptr = torch.zeros(
|
||||
input_lengths.shape[0] + 1, device=input_lengths.device, dtype=torch.int32
|
||||
)
|
||||
# Round up to page size and then calculate the cumulative sum to get
|
||||
# the indices into the block table.
|
||||
torch.add(input_lengths, page_size - 1, out=indptr[1:])
|
||||
indptr[1:].div_(page_size, rounding_mode="floor")
|
||||
indptr[1:].cumsum_(-1)
|
||||
|
||||
# Get the lengths of the last page in a block.
|
||||
last_page_len = torch.empty(
|
||||
input_lengths.shape[0], dtype=torch.int32, device=input_lengths.device
|
||||
)
|
||||
torch.sub(input_lengths, 1, out=last_page_len)
|
||||
last_page_len.remainder_(page_size)
|
||||
last_page_len += 1
|
||||
|
||||
token = decode_state.set(state)
|
||||
|
||||
try:
|
||||
state.begin_forward(
|
||||
indptr=indptr,
|
||||
indices=block_tables,
|
||||
last_page_len=last_page_len,
|
||||
num_qo_heads=num_heads,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_size,
|
||||
page_size=page_size,
|
||||
data_type=dtype,
|
||||
q_data_type=dtype,
|
||||
window_left=window_left,
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
state.end_forward()
|
||||
if token is not None:
|
||||
decode_state.reset(token)
|
@ -0,0 +1,97 @@
import torch
from text_generation_server.layers.attention import Seqlen, HPUPagedAttentionMetadata
from typing import Optional
from text_generation_server.layers.attention.kv_cache import KVCache, KVScales
from vllm_hpu_extension import ops
from vllm_hpu_extension.utils import Matmul
from habana_frameworks.torch.hpex.kernels import FusedSDPA
from vllm_hpu_extension.utils import ModuleFusedSDPA
import os

SUPPORTS_WINDOWING = False


def fetch_from_cache(cache, blocks):
    if os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() == "true":
        return cache[: blocks.size(0)]
    else:
        return cache.index_select(0, blocks)


def attention(
    *,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: KVCache,
    kv_scales: KVScales,
    seqlen: Seqlen,
    block_tables: torch.Tensor,
    softmax_scale: float,
    window_size_left: int = -1,
    causal: bool = True,
    softcap: Optional[float] = None,
):
    fsdpa_op = ModuleFusedSDPA(FusedSDPA)
    bs = seqlen.input_lengths.shape[0]
    _, head_num, head_size = query.shape
    _, kv_head_num, head_size = key.shape
    query = query.view(bs, -1, head_num, head_size).transpose(1, 2)
    key = key.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
    value = value.view(bs, -1, kv_head_num, head_size).transpose(1, 2)
    attn_output = fsdpa_op(
        query,
        key,
        value,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=causal,
        scale=softmax_scale,
        softmax_mode="None",
        recompute_mode=None,
        valid_sequence_lengths=seqlen.input_lengths,
        padding_side="left",
    )
    attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
    return attn_output


def paged_attention(
    query: torch.Tensor,
    kv_cache: KVCache,
    kv_head_mapping: torch.Tensor,
    softmax_scale: float,
    block_tables: torch.Tensor,
    seqlen: Seqlen,
    *,
    kv_scales: KVScales,
    softcap: Optional[float] = None,
    hpu_attention_meta: HPUPagedAttentionMetadata,
):
    batch_size, head_num, head_size = query.shape
    output = ops.flat_pa(
        query=query,
        key_cache=kv_cache.key,
        value_cache=kv_cache.value,
        block_list=hpu_attention_meta.block_list,
        block_mapping=hpu_attention_meta.block_mapping,
        block_bias=hpu_attention_meta.attn_bias,
        block_scales=hpu_attention_meta.block_scales,
        block_groups=hpu_attention_meta.block_groups,
        scale=softmax_scale,
        matmul_qk_op=Matmul(),
        matmul_av_op=Matmul(),
        batch2block_matmul_op=Matmul(),
        block2batch_matmul_op=Matmul(),
        keys_fetch_func=fetch_from_cache,
        values_fetch_func=fetch_from_cache,
    )
    # Reshape the output tensor.
    return output.view(batch_size, head_num, head_size)


__all__ = [
    "SUPPORTS_WINDOWING",
    "attention",
    "paged_attention",
]
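For reference, a minimal sketch (made-up sizes, plain PyTorch only) of the view/transpose that `attention` above applies before calling FusedSDPA: the flat `[total_tokens, heads, head_size]` layout is reinterpreted as `[batch, heads, seq, head_size]`.

import torch

bs, seq, head_num, head_size = 2, 8, 4, 64
flat_q = torch.randn(bs * seq, head_num, head_size)   # one row per token
q = flat_q.view(bs, -1, head_num, head_size).transpose(1, 2)
assert q.shape == (bs, head_num, seq, head_size)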
@ -1,82 +0,0 @@
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import torch
|
||||
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from typing import Optional
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
PREFILL_IN_KV_CACHE = False
|
||||
|
||||
|
||||
def attention(
|
||||
q: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale,
|
||||
window_size_left=-1,
|
||||
causal=True,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
out = torch.empty_like(q)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
ipex.llm.functional.varlen_attention(
|
||||
q.contiguous() if q.device.type == "xpu" else q,
|
||||
key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
|
||||
value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_q,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
||||
key, value, key_cache, value_cache, slots
|
||||
)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
out = torch.empty_like(query)
|
||||
ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
seqlen.input_lengths,
|
||||
BLOCK_SIZE,
|
||||
max_s,
|
||||
None,
|
||||
)
|
||||
return out
|
@ -0,0 +1,141 @@
from typing import Tuple
from dataclasses import dataclass, field

import torch

from text_generation_server.models.globals import BLOCK_SIZE
from text_generation_server.utils.weights import Weights


@dataclass
class KVScales:
    """
    Key-value scales for FP8 KV cache.

    This data class stores key and value scales both as a GPU tensor and
    as a CPU float. This inconvenience is necessary because some functions
    (e.g. scaling kernels) take scales as a GPU tensor, whereas others
    (e.g. flashinfer) take scales as a CPU scalar.
    """

    key_scale: torch.Tensor
    value_scale: torch.Tensor
    key_scale_cpu: float = field(init=False)
    value_scale_cpu: float = field(init=False)

    def __post_init__(self):
        if self.key_scale.numel() != 1 or self.value_scale.numel() != 1:
            raise ValueError("Key and value scales must be scalar tensors.")

        self.key_scale_cpu = self.key_scale.item()
        self.value_scale_cpu = self.value_scale.item()


class KVCache:
    """
    Key-value cache for attention layers.
    """

    kv_cache: Tuple[torch.Tensor, torch.Tensor]

    def __init__(
        self,
        *,
        num_blocks: int,
        num_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """Construct the key-value cache for a layer."""
        ## TODO FP8 kv cache support

        self.kv_cache = (
            torch.zeros(
                (num_blocks, BLOCK_SIZE, num_heads, head_size),
                dtype=dtype,
                device=device,
            ),
            torch.zeros(
                (num_blocks, BLOCK_SIZE, num_heads, head_size),
                dtype=dtype,
                device=device,
            ),
        )

    @property
    def dtype(self):
        """Get the data type of the cache."""
        return self.kv_cache[0].dtype

    @property
    def key(self):
        """Get the key cache."""

        return self.kv_cache[0]

    @property
    def value(self):
        """Get the value cache."""

        return self.kv_cache[1]

    def store(
        self,
        *,
        key: torch.Tensor,
        value: torch.Tensor,
        slots: torch.Tensor,
        kv_scales: KVScales,
    ):
        """Store the key and value at the given slots."""
        ## TODO FP8 kv cache support

        key_cache = self.kv_cache[0]
        value_cache = self.kv_cache[1]

        paged_reshape_and_cache(
            key,
            value,
            key_cache,
            value_cache,
            slots,
            kv_scales.key_scale_cpu,
            kv_scales.value_scale_cpu,
        )


def paged_reshape_and_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slots: torch.Tensor,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
):

    from vllm_hpu_extension import cache_ops

    block_idx = slots // BLOCK_SIZE
    block_offset = slots % BLOCK_SIZE
    cache_ops.insert_or_update_cache(key, key_cache, block_idx, block_offset)
    cache_ops.insert_or_update_cache(value, value_cache, block_idx, block_offset)


def get_kv_scales(weights: Weights, prefix: str) -> KVScales:
    """Load KV cache scales."""

    key_scale = torch.tensor(1.0, dtype=torch.float32, device=weights.device)
    value_scale = key_scale
    if weights.has_tensor(f"{prefix}.k_scale") and weights.has_tensor(
        f"{prefix}.v_scale"
    ):
        key_scale = weights.get_tensor(f"{prefix}.k_scale", to_dtype=False).float()
        value_scale = weights.get_tensor(f"{prefix}.v_scale", to_dtype=False).float()
    elif weights.has_tensor(f"{prefix}.kv_scale"):
        # Fall back to the older, more coarse-grained scale when available.
        key_scale = weights.get_tensor(f"{prefix}.kv_scale").float()
        value_scale = key_scale

    return KVScales(key_scale=key_scale, value_scale=value_scale)
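For illustration, a minimal sketch (made-up sizes, assuming the `KVScales` class above is in scope) of how `paged_reshape_and_cache` maps flat slot indices to block/offset pairs, and of the default scales `get_kv_scales` returns when a checkpoint carries no k/v scales:

import torch

EXAMPLE_BLOCK_SIZE = 16                    # illustrative, not the real BLOCK_SIZE
slots = torch.tensor([0, 15, 16, 35])
block_idx = slots // EXAMPLE_BLOCK_SIZE    # tensor([0, 0, 1, 2])
block_offset = slots % EXAMPLE_BLOCK_SIZE  # tensor([0, 15, 0, 3])

# With no k_scale/v_scale/kv_scale tensors in the checkpoint, scales stay at 1.0.
default_scales = KVScales(key_scale=torch.tensor(1.0), value_scale=torch.tensor(1.0))
assert default_scales.key_scale_cpu == 1.0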
@ -1,308 +0,0 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.utils.log import log_master
|
||||
from loguru import logger
|
||||
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
is_sm75 = major == 7 and minor == 5
|
||||
|
||||
_PARTITION_SIZE_V1V2 = 512
|
||||
_PARTITION_SIZE_CUSTOM = 256
|
||||
|
||||
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
|
||||
ENGINE = "triton" if use_triton else "ck"
|
||||
|
||||
PREFILL_IN_KV_CACHE = False
|
||||
|
||||
use_rocm_custom_paged_attn = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
|
||||
try:
|
||||
if use_rocm_custom_paged_attn:
|
||||
from vllm._custom_C import paged_attention_custom
|
||||
except ImportError as e:
|
||||
log_master(
|
||||
logger.info,
|
||||
f"Custom Paged Attention not available. Complete error: {e}",
|
||||
)
|
||||
use_rocm_custom_paged_attn = False
|
||||
|
||||
try:
|
||||
import vllm._custom_ops as ops
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
|
||||
)
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
):
|
||||
if ATTENTION == "flashdecoding":
|
||||
shape = key_cache.shape
|
||||
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
|
||||
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
|
||||
else:
|
||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
kv_head_mapping: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
|
||||
# Copyright 2023 The vLLM team. All rights
|
||||
# reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
if softcap is not None:
|
||||
raise RuntimeError("Paged attention doesn't support softcapping")
|
||||
|
||||
# value_cache => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
|
||||
num_kv_heads = key_cache.shape[1]
|
||||
gqa_ratio = num_heads // num_kv_heads
|
||||
use_custom = (
|
||||
use_rocm_custom_paged_attn
|
||||
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
|
||||
and (head_size == 128 or head_size == 64)
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_s <= 32768
|
||||
)
|
||||
|
||||
if not use_custom:
|
||||
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
|
||||
else:
|
||||
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM
|
||||
|
||||
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
|
||||
input_lengths = seqlen.input_lengths
|
||||
|
||||
out = torch.empty_like(query)
|
||||
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
use_v1 = (
|
||||
max_s <= 8192
|
||||
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
|
||||
and not use_custom
|
||||
)
|
||||
if use_v1:
|
||||
ops.paged_attention_v1(
|
||||
out,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=out.dtype,
|
||||
device=out.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=out.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
if not use_custom:
|
||||
ops.paged_attention_v2(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
kv_head_mapping,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
1.0,
|
||||
)
|
||||
else:
|
||||
paged_attention_custom(
|
||||
out,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
None,
|
||||
"auto",
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
if ENGINE != "triton":
|
||||
try:
|
||||
import flash_attn_2_cuda
|
||||
|
||||
log_master(
|
||||
logger.info,
|
||||
"ROCm: using Flash Attention 2 Composable Kernel implementation.",
|
||||
)
|
||||
except ImportError as e:
|
||||
if major >= 8:
|
||||
architecture_suffix = f"-{SYSTEM}"
|
||||
raise ImportError(
|
||||
"Flash Attention V2 is not installed.\n"
|
||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
|
||||
)
|
||||
elif is_sm75:
|
||||
raise ImportError(
|
||||
"Flash Attention is not installed.\n"
|
||||
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
|
||||
"or install flash attention with `cd server && make install install-flash-attention`"
|
||||
) from e
|
||||
else:
|
||||
for idx in range(torch.cuda.device_count()):
|
||||
name = torch.cuda.get_device_name(idx)
|
||||
if "MI210" not in name and "MI250" not in name:
|
||||
raise ImportError(
|
||||
f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
|
||||
)
|
||||
raise ImportError(
|
||||
f"AMD GPU with ROCm capability {major} {minor} is not supported"
|
||||
) from e
|
||||
|
||||
|
||||
SUPPORTS_WINDOWING = False
|
||||
if ENGINE == "ck":
|
||||
|
||||
def attention(
|
||||
q,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap: float = 0.0,
|
||||
):
|
||||
if window_size_left <= 0 and window_size_left != -1:
|
||||
raise ValueError("`window_size_left` must be > 0 or -1")
|
||||
|
||||
out = torch.empty_like(q)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
0.0,
|
||||
softmax_scale,
|
||||
False,
|
||||
causal,
|
||||
window_size_left,
|
||||
0,
|
||||
softcap,
|
||||
False,
|
||||
None,
|
||||
)[0]
|
||||
|
||||
elif ENGINE == "triton":
|
||||
from .flash_attn_triton import triton_attention
|
||||
|
||||
def attention(
|
||||
q,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
block_tables: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
window_size_left: int = -1,
|
||||
causal: bool = True,
|
||||
softcap: Optional[float] = None,
|
||||
):
|
||||
if softcap is not None:
|
||||
raise NotImplementedError("softcap is only available with CK flash attn")
|
||||
|
||||
out = torch.empty_like(q)
|
||||
|
||||
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
|
||||
output, _ = triton_attention(
|
||||
q,
|
||||
key_cache,
|
||||
value_cache,
|
||||
out,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.cu_seqlen_q,
|
||||
seqlen.max_q,
|
||||
seqlen.max_k,
|
||||
causal,
|
||||
softmax_scale,
|
||||
)
|
||||
return output
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Unknown attention engine {ENGINE}")
|
@ -0,0 +1,3 @@
from .hpu import WQLinear

__all__ = ["WQLinear"]
@ -0,0 +1,3 @@
from .loader import CompressedTensorsLoader

__all__ = ["CompressedTensorsLoader"]
@ -0,0 +1,196 @@
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from compressed_tensors import QuantizationConfig, QuantizationStatus
|
||||
from compressed_tensors.config import CompressionFormat
|
||||
from compressed_tensors.quantization import (
|
||||
QuantizationScheme,
|
||||
QuantizationType,
|
||||
find_name_or_class_matches,
|
||||
)
|
||||
from loguru import logger
|
||||
from pydantic import ValidationError
|
||||
from torch import nn
|
||||
|
||||
from text_generation_server.layers.compressed_tensors.w8an_fp import W8ANFpLoader
|
||||
from text_generation_server.layers.compressed_tensors.w8a8_int import W8A8IntLoader
|
||||
from text_generation_server.layers.compressed_tensors.wna16_int_24 import (
|
||||
WNA16Int24Loader,
|
||||
)
|
||||
from text_generation_server.layers.compressed_tensors.wna16_int import WNA16IntLoader
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
WeightsLoader,
|
||||
)
|
||||
|
||||
# compressed-tensors can match modules as quantization targets. However,
|
||||
# they need to be objects rather than classes or class names. Since we
|
||||
# need to match `Linear` targets, make an instance that can be re-used.
|
||||
_EMPTY_LINEAR: nn.Module = nn.Linear(0, 0)
|
||||
|
||||
|
||||
class CompressedTensorsLoader(WeightsLoader):
|
||||
"""Loader for checkpoints stored in the compressed-tensors format."""
|
||||
|
||||
def __init__(self, config: Dict[str, Any]):
|
||||
quantization_config_raw = config.get("quantization_config")
|
||||
if quantization_config_raw is None:
|
||||
# `compression_config` was renamed to `quantization_config`; support
|
||||
# retained for backward compatibility.
|
||||
quantization_config_raw = config.get("compression_config")
|
||||
if quantization_config_raw is None:
|
||||
raise ValueError(
|
||||
"Checkpoint does not have compressed-tensors configuration"
|
||||
)
|
||||
|
||||
try:
|
||||
quantization_config = QuantizationConfig.model_validate(
|
||||
quantization_config_raw
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise ValueError("Cannot parse compressed-tensors configuration") from e
|
||||
|
||||
if quantization_config.quantization_status not in (
|
||||
QuantizationStatus.COMPRESSED,
|
||||
QuantizationStatus.FROZEN,
|
||||
):
|
||||
raise ValueError(
|
||||
f"Model quantization was not finished, status was: {quantization_config.quantization_status}"
|
||||
)
|
||||
|
||||
self.ignore = (
|
||||
quantization_config.ignore if quantization_config.ignore is not None else []
|
||||
)
|
||||
self.loaders = self._get_target_loaders(quantization_config)
|
||||
|
||||
for target, loader in self.loaders.items():
|
||||
log_once(
|
||||
logger.info,
|
||||
f"Using {loader} for compressed-tensors target '{target}'",
|
||||
)
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
loader = self._lookup_loader(prefix)
|
||||
return loader.get_weights(weights, prefix)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: "Weights",
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
loader = self._lookup_loader(prefix)
|
||||
return loader.get_weights_col_packed(weights, prefix, block_sizes)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
loader = self._lookup_loader(prefixes[0])
|
||||
return loader.get_multi_weights_col(weights, prefixes, dim)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
loader = self._lookup_loader(prefix)
|
||||
return loader.get_weights_row(weights, prefix)
|
||||
|
||||
def _get_target_loaders(
|
||||
self, quantization_config: QuantizationConfig
|
||||
) -> Dict[str, WeightsLoader]:
|
||||
"""
|
||||
A compressed-tensors checkpoint can use different quantizations
|
||||
for different targets. This method returns a dictionary with a
|
||||
loader per target.
|
||||
"""
|
||||
|
||||
loaders: Dict[str, WeightsLoader] = {}
|
||||
|
||||
format = quantization_config.format
|
||||
|
||||
for group_name, group in quantization_config.config_groups.items():
|
||||
# The group configuration can be a string, but does that ever
|
||||
# happen in a serialized quantization config?
|
||||
assert isinstance(group, QuantizationScheme)
|
||||
|
||||
loader = self._create_loader_for_group(format, group_name, group)
|
||||
|
||||
# A quantized parameter group can have multiple targets, add the
|
||||
# loader for all the targets.
|
||||
for target in group.targets:
|
||||
if target in loaders:
|
||||
raise ValueError(
|
||||
f"Target '{target} has multiple configured loaders'"
|
||||
)
|
||||
loaders[target] = loader
|
||||
|
||||
return loaders
|
||||
|
||||
def _create_loader_for_group(
|
||||
self, format: str, group_name: str, group: QuantizationScheme
|
||||
) -> WeightsLoader:
|
||||
"""
|
||||
Find and create a loader for the group with the given quantization
|
||||
scheme.
|
||||
"""
|
||||
# NOTE: we ignore group.output_activations because we don't support
|
||||
# output quantization yet.
|
||||
|
||||
input_activations = group.input_activations
|
||||
weights = group.weights
|
||||
if (
|
||||
format
|
||||
in {
|
||||
CompressionFormat.float_quantized.value,
|
||||
CompressionFormat.naive_quantized.value,
|
||||
}
|
||||
and weights is not None
|
||||
and weights.type == QuantizationType.FLOAT
|
||||
and weights.num_bits == 8
|
||||
):
|
||||
# FP W8A8 or W8A16.
|
||||
return W8ANFpLoader(input_activations=input_activations, weights=weights)
|
||||
elif (
|
||||
format == CompressionFormat.pack_quantized.value
|
||||
and weights is not None
|
||||
and weights.type == QuantizationType.INT
|
||||
and weights.num_bits in (4, 8)
|
||||
):
|
||||
# INT W4A16 or W8A16 (GPTQ/AWQ-like).
|
||||
return WNA16IntLoader(weights)
|
||||
elif (
|
||||
format == CompressionFormat.marlin_24.value
|
||||
and weights is not None
|
||||
and weights.type == QuantizationType.INT
|
||||
and weights.num_bits in (4, 8)
|
||||
):
|
||||
return WNA16Int24Loader(weights)
|
||||
elif (
|
||||
format
|
||||
in {
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.naive_quantized.value,
|
||||
}
|
||||
and weights is not None
|
||||
and weights.type == QuantizationType.INT
|
||||
and weights.num_bits == 8
|
||||
):
|
||||
return W8A8IntLoader(input_args=input_activations, weight_args=weights)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Group '{group_name}' has unsupported compressed-tensors configurtion"
|
||||
)
|
||||
|
||||
def _lookup_loader(self, prefix: str) -> WeightsLoader:
|
||||
"""
|
||||
Look up the loader to use for a given parameter name (prefix).
|
||||
"""
|
||||
|
||||
if len(find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.ignore)) > 0:
|
||||
return DefaultWeightsLoader(UnquantizedWeight)
|
||||
|
||||
# We currently only handle linear layers, so unconditionally pass
|
||||
# a `Linear` instance.
|
||||
targets = find_name_or_class_matches(prefix, _EMPTY_LINEAR, self.loaders.keys())
|
||||
if len(targets) == 0:
|
||||
raise ValueError(
|
||||
f"Cannot find compressed-tensors target for prefix: {prefix}"
|
||||
)
|
||||
return self.loaders[targets[0]]
|
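As a rough, hypothetical illustration of the checkpoint configuration this loader consumes (field names follow the compressed-tensors schema used above; real configs typically carry more fields):

hypothetical_config = {
    "quantization_config": {
        "format": "float-quantized",
        "quantization_status": "compressed",
        "ignore": ["lm_head"],
        "config_groups": {
            "group_0": {
                "targets": ["Linear"],
                "weights": {"type": "float", "num_bits": 8, "symmetric": True, "dynamic": False},
                "input_activations": {"type": "float", "num_bits": 8, "dynamic": True},
            }
        },
    }
}
# CompressedTensorsLoader(hypothetical_config) would route every `Linear` target
# to W8ANFpLoader, while prefixes matching `ignore` fall back to
# DefaultWeightsLoader(UnquantizedWeight).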
@ -0,0 +1,239 @@
|
||||
from typing import List, Optional, Union, TypeVar
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
|
||||
|
||||
from text_generation_server.layers.fp8 import _load_scalar_or_matrix_scale
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
|
||||
quantization = None
|
||||
|
||||
|
||||
class W8A8IntLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for w8a8 integer compressed-tensors parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_args: Optional[QuantizationArgs],
|
||||
weight_args: QuantizationArgs,
|
||||
):
|
||||
if weight_args.type != QuantizationType.INT or weight_args.num_bits != 8:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} only supports w8a8 int checkpoints"
|
||||
)
|
||||
|
||||
if not weight_args.symmetric:
|
||||
raise ValueError("Checkpoints with asymmetric weights are not supported")
|
||||
|
||||
self.load_weight_scale = not weight_args.dynamic
|
||||
|
||||
if input_args is not None:
|
||||
self.input_symmetric = input_args.symmetric
|
||||
|
||||
if not input_args.dynamic:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Forcing dynamic input quantization for compressed_tensors w8a8 int checkpoint (for better accuracy).",
|
||||
)
|
||||
else:
|
||||
self.input_symmetric = True
|
||||
|
||||
def __str__(self) -> str:
|
||||
def scale_to_str(scale):
|
||||
return "static" if scale else "dynamic"
|
||||
|
||||
def symmetric_to_str(symmetric):
|
||||
return "symmetric" if symmetric else "asymmetric"
|
||||
|
||||
return f"{self.__class__.__name__} (w8a8 int, input: dynamic/{symmetric_to_str(self.input_symmetric)}, weight: {scale_to_str(self.load_weight_scale)}/symmetric))"
|
||||
|
||||
def get_weights(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_tensor(f"{prefix}.weight", to_dtype=False)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(
|
||||
f"{prefix}.weight_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
w = weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes, to_dtype=False
|
||||
)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
if weight_scale.numel() > 1:
|
||||
weight_scale = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale",
|
||||
dim=0,
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
weight_scale = weight_scale.reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
w = [
|
||||
weights.get_sharded(f"{p}.weight", dim=0, to_dtype=False) for p in prefixes
|
||||
]
|
||||
shapes = [x.shape for x in w]
|
||||
|
||||
w = torch.cat(w, dim=dim)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
]
|
||||
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1, 1)
|
||||
|
||||
return Int8Weight(
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1, to_dtype=False)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(
|
||||
f"{prefix}.weight_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
return Int8Weight(
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
|
||||
|
||||
OtherT = TypeVar("OtherT")
|
||||
|
||||
|
||||
def _get_tensor_or_else(
|
||||
weights: Weights, prefix: str, other: OtherT
|
||||
) -> Union[torch.Tensor, OtherT]:
|
||||
# Even if a checkpoint uses e.g. zero-points, they can be elided:
|
||||
# https://github.com/neuralmagic/compressed-tensors/blob/db6ccb25b265e8370813ecab5e95714a6728b5a6/src/compressed_tensors/compressors/quantized_compressors/base.py#L105
|
||||
if weights.has_tensor(prefix):
|
||||
return weights.get_tensor(prefix, to_dtype=False)
|
||||
else:
|
||||
return other
|
||||
|
||||
|
||||
@dataclass
|
||||
class Int8Weight(Weight):
|
||||
input_symmetric: bool
|
||||
weight: torch.Tensor
|
||||
weight_scale: Optional[torch.Tensor]
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
if self.weight_scale is None:
|
||||
assert quantization is not None
|
||||
qweight, weight_scale, _ = quantization.scaled_int8_quant(self.weight)
|
||||
return W8A8IntLinear(
|
||||
bias=bias,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=qweight,
|
||||
weight_scale=weight_scale,
|
||||
)
|
||||
else:
|
||||
return W8A8IntLinear(
|
||||
bias=bias,
|
||||
input_symmetric=self.input_symmetric,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
)
|
||||
|
||||
|
||||
class W8A8IntLinear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
bias: Optional[torch.Tensor],
|
||||
input_symmetric: bool,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
weight_scale = weight_scale.to(torch.float32)
|
||||
|
||||
self.bias = bias
|
||||
self.input_symmetric = input_symmetric
|
||||
# cutlass kernels require transposed weights.
|
||||
self.weight = weight.t()
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
if input_symmetric:
|
||||
self.zero_point_adj = None
|
||||
else:
|
||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md#scaledepilogueazp
|
||||
self.zero_point_adj = self.weight.sum(
|
||||
dim=0, keepdim=True, dtype=torch.int32
|
||||
)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
assert quantization is not None
|
||||
|
||||
qinput, input_scale, input_zero_point = quantization.scaled_int8_quant(
|
||||
input=input,
|
||||
scale=None,
|
||||
azp=None,
|
||||
symmetric=self.input_symmetric,
|
||||
)
|
||||
|
||||
if self.input_symmetric:
|
||||
return quantization.cutlass_scaled_mm(
|
||||
a=qinput,
|
||||
b=self.weight,
|
||||
scale_a=input_scale,
|
||||
scale_b=self.weight_scale,
|
||||
out_dtype=input.dtype,
|
||||
bias=self.bias,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.zero_point_adj is not None
|
||||
and input_scale is not None
|
||||
and (self.input_symmetric or input_zero_point is not None)
|
||||
)
|
||||
|
||||
return quantization.cutlass_scaled_mm_azp(
|
||||
a=qinput,
|
||||
b=self.weight,
|
||||
scale_a=input_scale,
|
||||
scale_b=self.weight_scale,
|
||||
out_dtype=input.dtype,
|
||||
azp_adj=self.zero_point_adj,
|
||||
azp=input_zero_point,
|
||||
bias=self.bias,
|
||||
)
|
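For intuition, a plain-PyTorch reference of the symmetric dequantization math that the cutlass call above performs (the real kernel accumulates in int32; all names here are illustrative):

import torch

def int8_symmetric_ref(q_input, qweight_t, input_scale, weight_scale):
    # q_input: [tokens, in_features] int8; qweight_t: [in_features, out_features] int8.
    # input_scale: [tokens, 1] per-token scale; weight_scale: [out_features] per-channel scale.
    acc = q_input.to(torch.float32) @ qweight_t.to(torch.float32)
    return acc * input_scale * weight_scale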
@ -0,0 +1,168 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
|
||||
|
||||
from text_generation_server.layers.fp8 import (
|
||||
Fp8Weight,
|
||||
_load_scalar_or_matrix_scale,
|
||||
)
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
|
||||
|
||||
class W8ANFpLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for W8A8/W8A16 FP compressed-tensors parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
input_activations: Optional[QuantizationArgs],
|
||||
weights: QuantizationArgs,
|
||||
):
|
||||
assert weights.type == QuantizationType.FLOAT and weights.num_bits == 8
|
||||
|
||||
# We ignore the `strategy` option which sets the scales to be
|
||||
# per-tensor, per-channel or per-token. What scales are supported
|
||||
# is dependent on the kernels used (e.g. cutlass can do tokenwise,
|
||||
# Torch cannot, and FP8-Marlin does not quantize inputs at all).
|
||||
# So, instead we try to use the best-possible configuration.
|
||||
|
||||
self.load_weight_scale = not weights.dynamic
|
||||
self.load_input_scale = (
|
||||
input_activations is not None and not input_activations.dynamic
|
||||
)
|
||||
self.force_w8a16 = (
|
||||
input_activations is not None and input_activations.num_bits == 16
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
def scale_to_str(scale):
|
||||
return "static" if scale else "dynamic"
|
||||
|
||||
quantization_type = f"W8A{16 if self.force_w8a16 else 8}"
|
||||
|
||||
return f"{self.__class__.__name__} ({quantization_type}, weight: {scale_to_str(self.load_weight_scale)}, input: {scale_to_str(self.load_input_scale)})"
|
||||
|
||||
def get_weights(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_tensor(f"{prefix}.weight")
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(
|
||||
f"{prefix}.input_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=input_scale,
|
||||
dtype=weights.dtype,
|
||||
force_w8a16=self.force_w8a16,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
w = weights.get_packed_sharded(
|
||||
f"{prefix}.weight", dim=0, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
if weight_scale.numel() > 1:
|
||||
weight_scale = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale",
|
||||
dim=0,
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||
if input_scale.numel() > 1:
|
||||
input_scale = weights.get_packed_sharded(
|
||||
f"{prefix}.input_scale",
|
||||
dim=0,
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
input_scale = input_scale.reshape(-1).max()
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=input_scale,
|
||||
dtype=weights.dtype,
|
||||
force_w8a16=self.force_w8a16,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
|
||||
# FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
|
||||
w = [
|
||||
weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
|
||||
]
|
||||
shapes = [x.shape for x in w]
|
||||
|
||||
# Concat then send to the device
|
||||
w = torch.cat(w, dim=dim).to(weights.device)
|
||||
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
]
|
||||
weight_scale = torch.cat(weight_scale, dim=0).reshape(-1)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
if weights.has_tensor(f"{p}.input_scale")
|
||||
]
|
||||
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
|
||||
input_scale = (
|
||||
torch.cat(input_scale, dim=0).reshape(-1).max()
|
||||
if len(input_scale) != 0
|
||||
else None
|
||||
)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=input_scale,
|
||||
dtype=weights.dtype,
|
||||
force_w8a16=self.force_w8a16,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
weight_scale = None
|
||||
if self.load_weight_scale:
|
||||
weight_scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
|
||||
input_scale = None
|
||||
if self.load_input_scale:
|
||||
input_scale = weights.get_tensor(
|
||||
f"{prefix}.input_scale", to_dtype=False
|
||||
).reshape(-1)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=weight_scale,
|
||||
input_scale=input_scale,
|
||||
dtype=weights.dtype,
|
||||
force_w8a16=self.force_w8a16,
|
||||
)
|
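As a hedged usage sketch (the QuantizationArgs keyword values below are assumptions, not taken from this diff): a group with static FP8 weights and dynamic FP8 activations yields a loader that loads weight scales but not input scales.

from compressed_tensors.quantization import QuantizationArgs, QuantizationType

weight_args = QuantizationArgs(type=QuantizationType.FLOAT, num_bits=8, dynamic=False)
act_args = QuantizationArgs(type=QuantizationType.FLOAT, num_bits=8, dynamic=True)
loader = W8ANFpLoader(input_activations=act_args, weights=weight_args)
# str(loader) would read roughly "W8ANFpLoader (W8A8, weight: static, input: dynamic)".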
@ -0,0 +1,188 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering, QuantizationArgs
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.layers.marlin.gptq import repack_gptq_for_marlin
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
|
||||
|
||||
class WNA16IntLoader(WeightsLoader):
|
||||
"""
|
||||
Loader for W4A16/W8A16 INT compressed-tensors parameters.
|
||||
"""
|
||||
|
||||
def __init__(self, weights: QuantizationArgs):
|
||||
self.weights = weights
|
||||
self.desc_act = self.weights.actorder == ActivationOrdering.GROUP
|
||||
self.groupsize = (
|
||||
-1 if self.weights.group_size is None else self.weights.group_size
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
quantization_type = f"W{self.weights.num_bits}A16"
|
||||
|
||||
return f"{self.__class__.__name__} ({quantization_type})"
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
weight_packed = weights.get_tensor(f"{prefix}.weight_packed").t()
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
zero_point = None
|
||||
if not self.weights.symmetric:
|
||||
zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t()
|
||||
|
||||
g_idx = None
|
||||
if self.desc_act:
|
||||
g_idx = weights.get_tensor(f"{prefix}.weight_g_idx")
|
||||
|
||||
scales = weights.get_tensor(f"{prefix}.weight.scales").t()
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=weight_packed.contiguous(),
|
||||
scales=scales,
|
||||
qzeros=zero_point,
|
||||
g_idx=g_idx,
|
||||
bits=self.weights.num_bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method="compressed-tensors",
|
||||
sym=self.weights.symmetric,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
try:
|
||||
weight_packed = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_packed", dim=0, block_sizes=block_sizes
|
||||
).t()
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
|
||||
)
|
||||
scales = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes
|
||||
).t()
|
||||
scales = scales.to(dtype=weights.dtype)
|
||||
|
||||
zero_point = None
|
||||
if not self.weights.symmetric:
|
||||
zero_point = weights.get_packed_sharded(
|
||||
f"{prefix}.qzeros", dim=0, block_sizes=block_sizes
|
||||
).t()
|
||||
|
||||
g_idx = None
|
||||
if self.desc_act:
|
||||
g_idx = weights.get_tensor(f"{prefix}.g_idx")
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=weight_packed.contiguous(),
|
||||
scales=scales,
|
||||
qzeros=zero_point,
|
||||
g_idx=g_idx,
|
||||
bits=self.weights.num_bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method="compressed-tensors",
|
||||
sym=self.weights.symmetric,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
try:
|
||||
weight_packed = torch.cat(
|
||||
[
|
||||
weights.get_sharded(f"{p}.weight_packed", dim=0).t()
|
||||
for p in prefixes
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load w{self.weights.num_bits}a16 weight, make sure the model is already quantized"
|
||||
)
|
||||
|
||||
scales = torch.cat(
|
||||
[weights.get_sharded(f"{p}.weight_scale", dim=0).t() for p in prefixes],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
zero_point = None
|
||||
if not self.weights.symmetric:
|
||||
zero_point = torch.cat(
|
||||
[weights.get_sharded(f"{p}.qzeros", dim=0).t() for p in prefixes], dim=1
|
||||
).t()
|
||||
|
||||
g_idx = None
|
||||
if self.desc_act:
|
||||
w = [weights.get_tensor(f"{p}.g_idx") for p in prefixes]
|
||||
for w2 in w[1:]:
|
||||
torch.testing.assert_close(w2, w[0])
|
||||
g_idx = w[0]
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=weight_packed.contiguous(),
|
||||
scales=scales,
|
||||
qzeros=zero_point,
|
||||
g_idx=g_idx,
|
||||
bits=self.weights.num_bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method="compressed-tensors",
|
||||
sym=self.weights.symmetric,
|
||||
sharded_infeatures=False,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
log_once(logger.info, "Using GPTQ-Marlin kernels")
|
||||
try:
|
||||
weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=1).t()
|
||||
except RuntimeError:
|
||||
raise RuntimeError(
|
||||
f"Cannot load `{self.quantize}` weight, make sure the model is already quantized."
|
||||
)
|
||||
|
||||
zero_point = None
|
||||
if not self.weights.symmetric:
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
zero_point = weights.get_tensor(f"{prefix}.weight_zero_point").t()
|
||||
else:
|
||||
zero_point = weights.get_sharded(
|
||||
f"{prefix}.weight_zero_point", dim=1
|
||||
).t()
|
||||
|
||||
g_idx = None
|
||||
if self.desc_act:
|
||||
g_idx = weights.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
if self.desc_act or self.groupsize == -1:
|
||||
scales = weights.get_tensor(f"{prefix}.weight_scale").t()
|
||||
else:
|
||||
scales = weights.get_sharded(f"{prefix}.weight_scale", dim=1).t()
|
||||
|
||||
sharded_in_features = weights.process_group.size() > 1
|
||||
|
||||
return repack_gptq_for_marlin(
|
||||
qweight=weight_packed.contiguous(),
|
||||
scales=scales,
|
||||
qzeros=zero_point,
|
||||
g_idx=g_idx,
|
||||
bits=self.weights.num_bits,
|
||||
desc_act=self.desc_act,
|
||||
groupsize=self.groupsize,
|
||||
quant_method="compressed-tensors",
|
||||
sym=self.weights.symmetric,
|
||||
sharded_infeatures=sharded_in_features,
|
||||
)
|
@ -0,0 +1,101 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationType
|
||||
from text_generation_server.layers.marlin.marlin import GPTQMarlin24Weight
|
||||
from text_generation_server.utils.weights import Weights, WeightsLoader
|
||||
|
||||
|
||||
class WNA16Int24Loader(WeightsLoader):
|
||||
"""
|
||||
Loader for W4A16/W8A16 INT 2:4 sparsity compressed-tensors checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, weight_args: QuantizationArgs):
|
||||
super().__init__()
|
||||
|
||||
if weight_args.type != QuantizationType.INT:
|
||||
raise ValueError(
|
||||
f"{type(self).__name__} only supports wNa8 int checkpoints"
|
||||
)
|
||||
|
||||
if weight_args.strategy == "group" and weight_args.group_size is None:
|
||||
raise ValueError("`group_size` must be set when `actorder` is `group`")
|
||||
|
||||
self.bits = weight_args.num_bits
|
||||
self.group_size = weight_args.group_size
|
||||
|
||||
def __str__(self) -> str:
|
||||
quantization_type = f"W{self.bits}A16 2:4 sparsity"
|
||||
|
||||
return f"{self.__class__.__name__} ({quantization_type})"
|
||||
|
||||
def get_weights(self, weights: Weights, prefix: str):
|
||||
"""
|
||||
Get weights at the given prefix and apply without tensor parallelism.
|
||||
"""
|
||||
weight_packed = weights.get_tensor(f"{prefix}.weight_packed")
|
||||
meta = weights.get_tensor(f"{prefix}.meta")
|
||||
scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
|
||||
return GPTQMarlin24Weight(
|
||||
weight_packed=weight_packed,
|
||||
meta=meta,
|
||||
scale_packed=scale_packed,
|
||||
bits=self.bits,
|
||||
)
|
||||
|
||||
def get_weights_col_packed(
|
||||
self,
|
||||
weights: Weights,
|
||||
prefix: str,
|
||||
block_sizes: Union[int, List[int]],
|
||||
):
|
||||
weight_packed = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_packed", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
meta = weights.get_packed_sharded(
|
||||
f"{prefix}.meta", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
scale_packed = weights.get_packed_sharded(
|
||||
f"{prefix}.scale_packed", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
return GPTQMarlin24Weight(
|
||||
weight_packed=weight_packed,
|
||||
meta=meta,
|
||||
scale_packed=scale_packed,
|
||||
bits=self.bits,
|
||||
)
|
||||
|
||||
def get_multi_weights_col(self, weights: Weights, prefixes: List[str], dim: int):
|
||||
weight_packed = torch.cat(
|
||||
[weights.get_sharded(f"{p}.weight_packed", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
meta = torch.cat(
|
||||
[weights.get_sharded(f"{p}.meta", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
scale_packed = torch.cat(
|
||||
[weights.get_sharded(f"{p}.scale_packed", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
return GPTQMarlin24Weight(
|
||||
weight_packed=weight_packed,
|
||||
meta=meta,
|
||||
scale_packed=scale_packed,
|
||||
bits=self.bits,
|
||||
)
|
||||
|
||||
def get_weights_row(self, weights: Weights, prefix: str):
|
||||
weight_packed = weights.get_sharded(f"{prefix}.weight_packed", dim=0)
|
||||
meta = weights.get_sharded(f"{prefix}.meta", dim=0)
|
||||
if self.group_size is None:
|
||||
scale_packed = weights.get_tensor(f"{prefix}.scale_packed")
|
||||
else:
|
||||
scale_packed = weights.get_sharded(f"{prefix}.scale_packed", dim=0)
|
||||
|
||||
return GPTQMarlin24Weight(
|
||||
weight_packed=weight_packed,
|
||||
meta=meta,
|
||||
scale_packed=scale_packed,
|
||||
bits=self.bits,
|
||||
)
|
@ -1,43 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from EETQ import quant_weights, w8_a16_gemm
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
|
||||
|
||||
@dataclass
|
||||
class EETQWeight(UnquantizedWeight):
|
||||
weight: torch.Tensor
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
try:
|
||||
from text_generation_server.layers.eetq import EETQLinear
|
||||
|
||||
return EETQLinear(self.weight, bias)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
|
||||
)
|
||||
|
||||
|
||||
class EETQLinear(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
device = weight.device
|
||||
if weight.dtype != torch.float16:
|
||||
weight = weight.to(dtype=torch.float16)
|
||||
weight = torch.t(weight).contiguous().cpu()
|
||||
weight, scale = quant_weights(weight, torch.int8, False)
|
||||
|
||||
self.weight = weight.cuda(device)
|
||||
self.scale = scale.cuda(device)
|
||||
self.bias = bias.cuda(device) if bias is not None else None
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
output = w8_a16_gemm(input, self.weight, self.scale)
|
||||
output = output + self.bias if self.bias is not None else output
|
||||
return output
|
@ -1,102 +1,165 @@
|
||||
import torch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union, List
|
||||
from typing import Optional, Tuple, Type, Union, List
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import (
|
||||
Weight,
|
||||
WeightsLoader,
|
||||
UnquantizedWeight,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.log import log_master, log_once
|
||||
import importlib.util
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
quantization = None
|
||||
w8a8_block_fp8_matmul = None
|
||||
per_token_group_quant_fp8 = None
|
||||
|
||||
quant_dtype: torch.dtype = torch.float8_e4m3fn
|
||||
|
||||
|
||||
FBGEMM_MM_AVAILABLE = False
|
||||
FBGEMM_DYN_AVAILABLE = False
|
||||
CUTLASS_FP8_AVAILABLE = False
|
||||
|
||||
|
||||
def is_fbgemm_gpu_available():
|
||||
try:
|
||||
return importlib.util.find_spec("fbgemm_gpu.experimental.gen_ai") is not None
|
||||
except ModuleNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
if is_fbgemm_gpu_available():
|
||||
if SYSTEM == "cuda":
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
FBGEMM_MM_AVAILABLE = major == 9
|
||||
FBGEMM_DYN_AVAILABLE = major >= 8
|
||||
else:
|
||||
log_master(logger.warning, "FBGEMM fp8 kernels are not installed.")
|
||||
|
||||
|
||||
def get_fp8_linear() -> torch.nn.Module:
|
||||
def get_fp8_linear(force_w8a16: bool = False) -> Type[torch.nn.Module]:
|
||||
"""
|
||||
Return an FP8 linear `Module` that is compatible with the current system.
|
||||
"""
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
if major == 8:
|
||||
from text_generation_server.layers.marlin import GPTQMarlinFP8Linear
|
||||
|
||||
return GPTQMarlinFP8Linear
|
||||
|
||||
# On other systems let Torch decide if the hardware supports FP8.
|
||||
return Fp8Linear
|
||||
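A minimal usage sketch of the selector above (how `force_w8a16` further steers the choice is not visible in this hunk):

linear_cls = get_fp8_linear(force_w8a16=False)
# On CUDA devices with compute capability 8.x this resolves to GPTQMarlinFP8Linear;
# on other hardware the plain Fp8Linear module is returned.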
|
||||
|
||||
def fp8_quantize(
|
||||
weight, scale_upper_bound=None, qdtype=torch.float8_e4m3fn, scalar=False
|
||||
):
|
||||
if FBGEMM_DYN_AVAILABLE and not scalar:
|
||||
qweight, scale = torch.ops.fbgemm.quantize_fp8_per_row(
|
||||
weight, bs=None, scale_ub=scale_upper_bound, output_dtype=qdtype
|
||||
)
|
||||
return qweight, scale
|
||||
def normalize_e4m3fn_to_native_float8(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
return weight, weight_scale, input_scale
|
||||
|
||||
|
||||
def per_tensor_dequantize(
|
||||
tensor: torch.Tensor,
|
||||
inv_scale: Union[float, torch.Tensor],
|
||||
dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
fake_qweight = tensor.to(dtype)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
|
||||
def requantize_with_max_scale(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
logical_widths: int,
|
||||
dtype: torch.dtype,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Max scale to be used for requantization.
|
||||
max_w_scale = weight_scale.max().float()
|
||||
|
||||
start = 0
|
||||
for idx, logical_width in enumerate(logical_widths):
|
||||
end = start + logical_width
|
||||
weight_dq = per_tensor_dequantize(
|
||||
weight[start:end, :], weight_scale[idx], dtype
|
||||
)
|
||||
weight[start:end, :], max_w_scale_normalized = fp8_quantize(
|
||||
weight_dq, max_w_scale
|
||||
)
|
||||
start = end
|
||||
|
||||
return weight, max_w_scale_normalized
|
||||
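Conceptually, each logical shard is dequantized with its own scale and re-expressed on the shared maximum scale, so all shards end up quantized consistently; a tiny numeric sketch with made-up values:

import torch

w_scales = torch.tensor([0.010, 0.025])   # per-shard scales; max is 0.025
q_shard0 = torch.tensor([100.0])          # fake quantized values of shard 0
deq = q_shard0 * w_scales[0]              # dequantize with the shard scale -> 1.0
req = deq / w_scales.max()                # requantize with the shared max scale -> 40.0
assert torch.isclose(req * w_scales.max(), deq)  # round-trips to the same value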
|
||||
|
||||
def fp8_quantize(
|
||||
weight: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
scale_upper_bound: Optional[torch.Tensor] = None,
|
||||
qdtype: torch.dtype = torch.float8_e4m3fn,
|
||||
scalar: bool = False,
|
||||
):
|
||||
"""
|
||||
This function returns a reciprocal of the scale, so that a tensor can be unscaled
|
||||
by multiplying it with the returned scale. If a scale is given through the `scale`
|
||||
argument, it must also be a reciprocal (so that scales from an FP8 checkpoint can
|
||||
be used without modification).
|
||||
"""
|
||||
if quantization is not None:
|
||||
shape = weight.shape
|
||||
qweight, scale = quantization.scaled_fp8_quant(
|
||||
weight.reshape(-1, shape[-1]),
|
||||
scale=scale,
|
||||
scale_ub=scale_upper_bound,
|
||||
# TODO: don't do this when we have to use the Torch kernel.
|
||||
use_per_token_if_dynamic=not scalar,
|
||||
)
|
||||
|
||||
return qweight.reshape(shape), scale
|
||||
|
||||
# weight, scale = quant_weights(weight, torch.int8, False)
|
||||
finfo = torch.finfo(qdtype)
|
||||
|
||||
if scale is None:
|
||||
# Calculate the scale as dtype max divided by absmax
|
||||
scale = finfo.max / weight.abs().max().clamp(min=1e-12, max=scale_upper_bound)
|
||||
# scale and clamp the tensor to bring it to
|
||||
# the representative range of float8 data type
|
||||
# (as default cast is unsaturated)
|
||||
qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
scale = scale.float().reciprocal()
|
||||
else:
|
||||
# Use reciprocal to avoid more expensive division.
|
||||
qweight = (weight * scale.reciprocal()).clamp(min=finfo.min, max=finfo.max)
|
||||
|
||||
# Return both float8 data and the inverse scale (as float),
|
||||
# as both are required as inputs to torch._scaled_mm.
|
||||
qweight = qweight.to(qdtype)
|
||||
scale = scale.float().reciprocal()
|
||||
|
||||
return qweight, scale
|
||||
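The reciprocal convention described in the docstring means dequantization is a single multiply with the returned scale; a small sketch exercising the Torch fallback path above (assuming the optional `quantization` extension is not installed):

import torch

torch.manual_seed(0)
w = torch.randn(4, 8, dtype=torch.float32)
qw, inv_scale = fp8_quantize(w, scalar=True)
w_approx = qw.to(torch.float32) * inv_scale    # multiply by the returned (reciprocal) scale
assert torch.allclose(w, w_approx, rtol=0.1, atol=0.05)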
|
||||
|
||||
class HybridFP8UnquantLoader(WeightsLoader):
|
||||
"""Weight loader that loads FP8 and unquantized Torch tensors."""
|
||||
|
||||
def __init__(self, activation_scale_ub: Optional[float], to_fp8: bool):
|
||||
def __init__(
|
||||
self,
|
||||
activation_scale_ub: Optional[float],
|
||||
to_fp8: bool,
|
||||
weight_block_size: Optional[List[int]] = None,
|
||||
):
|
||||
self.activation_scale_ub = activation_scale_ub
|
||||
self.to_fp8 = to_fp8
|
||||
self.weight_block_size = weight_block_size
|
||||
|
||||
def get_weights(self, weights: "Weights", prefix: str):
|
||||
w = weights.get_tensor(f"{prefix}.weight")
|
||||
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FP8 branch
|
||||
scale = (
|
||||
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
.reshape(-1)
|
||||
.expand(w.shape[0])
|
||||
)
|
||||
if self.weight_block_size is not None:
|
||||
scale = weights.get_tensor(f"{prefix}.weight_scale_inv")
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
weight_block_size=self.weight_block_size,
|
||||
)
|
||||
# FP8 branch
|
||||
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
|
||||
input_scale = None
|
||||
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||
input_scale = (
|
||||
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||
.reshape(-1)
|
||||
.max()
|
||||
)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
input_scale=input_scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
if self.to_fp8:
|
||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||
@ -116,6 +179,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
# FP8 branch
|
||||
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
|
||||
if scale.numel() > 1:
|
||||
scale = weights.get_packed_sharded(
|
||||
f"{prefix}.weight_scale",
|
||||
@ -123,11 +187,25 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
scale = scale.reshape(-1).expand(w.shape[0])
|
||||
|
||||
input_scale = None
|
||||
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||
input_scale = weights.get_tensor(
|
||||
f"{prefix}.input_scale", to_dtype=False
|
||||
)
|
||||
if input_scale.numel() > 1:
|
||||
input_scale = weights.get_packed_sharded(
|
||||
f"{prefix}.input_scale",
|
||||
dim=0,
|
||||
block_sizes=block_sizes,
|
||||
to_dtype=False,
|
||||
)
|
||||
input_scale = input_scale.reshape(-1).max()
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
input_scale=input_scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
@ -148,15 +226,43 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
if self.weight_block_size is not None:
|
||||
scale = [
|
||||
weights.get_sharded(f"{p}.weight_scale_inv", dim=0, to_device=False)
|
||||
for p in prefixes
|
||||
]
|
||||
scale = torch.cat(scale, dim=dim)
|
||||
scale = scale.to(weights.device)
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
weight_block_size=self.weight_block_size,
|
||||
)
|
||||
|
||||
scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.weight_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
]
|
||||
scale = torch.cat(scale, dim=0).reshape(-1)
|
||||
|
||||
input_scale = [
|
||||
_load_scalar_or_matrix_scale(weights, f"{p}.input_scale", shape)
|
||||
for p, shape in zip(prefixes, shapes)
|
||||
if weights.has_tensor(f"{p}.input_scale")
|
||||
]
|
||||
assert len(input_scale) == 0 or len(input_scale) == len(prefixes)
|
||||
input_scale = (
|
||||
torch.cat(input_scale, dim=0).reshape(-1).max()
|
||||
if len(input_scale) != 0
|
||||
else None
|
||||
)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
input_scale=input_scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
@ -169,16 +275,34 @@ class HybridFP8UnquantLoader(WeightsLoader):
|
||||
w = weights.get_sharded(f"{prefix}.weight", dim=1)
|
||||
# FP8 branch
|
||||
if w.dtype == torch.float8_e4m3fn:
|
||||
scale = (
|
||||
weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
.reshape(-1)
|
||||
.expand(w.shape[0])
|
||||
)
|
||||
if self.weight_block_size is not None:
|
||||
# XXX: Yes, the weight is named scale_inv, but it seems to correspond to the scale itself.
scale = weights.get_sharded(f"{prefix}.weight_scale_inv", dim=1)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
weight_block_size=self.weight_block_size,
|
||||
)
|
||||
|
||||
scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
|
||||
|
||||
input_scale = None
|
||||
if weights.has_tensor(f"{prefix}.input_scale"):
|
||||
input_scale = (
|
||||
weights.get_tensor(f"{prefix}.input_scale", to_dtype=False)
|
||||
.reshape(-1)
|
||||
.max()
|
||||
)
|
||||
|
||||
return Fp8Weight(
|
||||
weight=w,
|
||||
weight_scale=scale,
|
||||
input_scale=input_scale,
|
||||
activation_scale_ub=self.activation_scale_ub,
|
||||
dtype=weights.dtype,
|
||||
)
|
||||
if self.to_fp8:
|
||||
return Fp8Weight(weight=w, dtype=weights.dtype)
|
||||
@ -191,83 +315,142 @@ class Fp8Weight(Weight):
|
||||
weight: torch.Tensor
|
||||
dtype: torch.dtype
|
||||
weight_scale: Optional[torch.Tensor] = None
|
||||
input_scale: Optional[torch.Tensor] = None
|
||||
activation_scale_ub: Optional[float] = None
|
||||
force_w8a16: bool = False
|
||||
weight_block_size: Optional[List[int]] = None
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
if self.weight_scale is None:
|
||||
return get_fp8_linear().from_unquant(self.weight, bias, self.dtype)
|
||||
return get_fp8_linear(force_w8a16=self.force_w8a16).from_unquant(
|
||||
self.weight, bias, self.dtype
|
||||
)
|
||||
# This is not checked by the fbgemm kernels, but they require contiguous
|
||||
# memory. Can be non-contiguous when we e.g. expand from scalars.
|
||||
self.weight_scale = self.weight_scale.contiguous()
|
||||
return get_fp8_linear().from_fp8(
|
||||
self.weight, self.weight_scale, self.activation_scale_ub, bias, self.dtype
|
||||
return get_fp8_linear(force_w8a16=self.force_w8a16).from_fp8(
|
||||
weight=self.weight,
|
||||
scale=self.weight_scale,
|
||||
dtype=self.dtype,
|
||||
bias=bias,
|
||||
input_scale=self.input_scale,
|
||||
scale_upper_bound=self.activation_scale_ub,
|
||||
weight_block_size=self.weight_block_size,
|
||||
)
|
||||
|
||||
|
||||
class Fp8Linear(torch.nn.Module):
|
||||
_device_identity_cache = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
qweight,
|
||||
scale,
|
||||
scale_upper_bound,
|
||||
bias,
|
||||
dtype,
|
||||
qweight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
scale_upper_bound: Optional[float] = None,
|
||||
weight_block_size: Optional[List[int]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if FBGEMM_MM_AVAILABLE:
|
||||
log_once(logger.info, "Using FBGEMM fp8 optimized kernels")
|
||||
if CUTLASS_FP8_AVAILABLE:
|
||||
log_once(logger.info, "Using cutlass w8a8 kernels")
|
||||
|
||||
self.dtype = dtype
|
||||
self.qweight = qweight
|
||||
self.scale = scale
|
||||
self.scale_upper_bound = (
|
||||
torch.tensor(
|
||||
[scale_upper_bound], dtype=torch.float32, device=qweight.device
|
||||
)
|
||||
if scale_upper_bound is not None
|
||||
else None
|
||||
self.scale = scale.float()
|
||||
self.input_scale = input_scale.float() if input_scale is not None else None
|
||||
self.weight_block_size = weight_block_size
|
||||
|
||||
if CUTLASS_FP8_AVAILABLE and scale_upper_bound is not None:
|
||||
self.scale_upper_bound = torch.tensor(
|
||||
scale_upper_bound, dtype=torch.float32, device=qweight.device
|
||||
)
|
||||
else:
|
||||
self.scale_upper_bound = scale_upper_bound
|
||||
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
@classmethod
|
||||
def from_unquant(cls, weight, bias, dtype):
|
||||
qweight, scale = fp8_quantize(weight, scalar=not FBGEMM_MM_AVAILABLE)
|
||||
qweight, scale = fp8_quantize(weight, scalar=not CUTLASS_FP8_AVAILABLE)
|
||||
return cls(
|
||||
qweight=qweight, scale=scale, scale_upper_bound=None, bias=bias, dtype=dtype
|
||||
qweight=qweight,
|
||||
scale=scale,
|
||||
dtype=dtype,
|
||||
bias=bias,
|
||||
input_scale=None,
|
||||
scale_upper_bound=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_fp8(cls, weight, scale, input_scale, bias, dtype):
|
||||
if FBGEMM_DYN_AVAILABLE:
|
||||
# fbgemm needs float32 scales.
|
||||
scale = scale.float()
|
||||
def from_fp8(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> "Fp8Linear":
|
||||
input_scale = kwargs.get("input_scale", None)
|
||||
scale_upper_bound = kwargs.get("scale_upper_bound", None)
|
||||
weight_block_size = kwargs.get("weight_block_size", None)
|
||||
|
||||
return cls(
|
||||
qweight=weight,
|
||||
scale=scale,
|
||||
scale_upper_bound=input_scale,
|
||||
input_scale=input_scale,
|
||||
scale_upper_bound=scale_upper_bound,
|
||||
bias=bias,
|
||||
dtype=dtype,
|
||||
weight_block_size=weight_block_size,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_shared_device_identity(cls, device):
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
# from PyTorch 2.5, so allocate a dummy tensor to pass as input_scale.
if device not in cls._device_identity_cache:
|
||||
cls._device_identity_cache[device] = torch.ones(1, device=device)
|
||||
return cls._device_identity_cache[device]
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if FBGEMM_MM_AVAILABLE:
|
||||
qinput, scale = fp8_quantize(
|
||||
input, scale_upper_bound=self.scale_upper_bound
|
||||
)
|
||||
|
||||
y = torch.ops.fbgemm.f8f8bf16_rowwise(
|
||||
if self.weight_block_size is not None:
|
||||
# https://arxiv.org/pdf/2412.19437
|
||||
# At a more granular level. As illustrated in Figure 7 (a), (1) for activations, we group and
|
||||
# scale elements on a 1x128 tile basis (i.e., per token per 128 channels); and (2) for weights, we
|
||||
# group and scale elements on a 128x128 block basis (i.e., per 128 input channels per 128 output
|
||||
# channels).
|
||||
qinput, scale = per_token_group_quant_fp8(input, self.weight_block_size[1])
|
||||
output = w8a8_block_fp8_matmul(
|
||||
qinput,
|
||||
self.qweight,
|
||||
scale,
|
||||
self.scale,
|
||||
use_fast_accum=True,
|
||||
bias=self.bias,
|
||||
self.weight_block_size,
|
||||
output_dtype=input.dtype,
|
||||
)
|
||||
return y.to(self.dtype)
|
||||
|
||||
qinput, scale = fp8_quantize(input, scalar=True)
|
||||
output, _ = torch._scaled_mm(
|
||||
if self.bias is not None:
|
||||
output = output + self.bias
|
||||
return output.to(dtype=input.dtype)
|
||||
if CUTLASS_FP8_AVAILABLE:
|
||||
# cutlass FP8 supports per-token scales, so get non-scalar scales.
|
||||
qinput, scale = fp8_quantize(
|
||||
input, scale_upper_bound=self.scale_upper_bound, scalar=False
|
||||
)
|
||||
return quantization.cutlass_scaled_mm(
|
||||
qinput, self.qweight.t(), scale, self.scale, input.dtype, self.bias
|
||||
)
|
||||
|
||||
qinput, scale = fp8_quantize(
|
||||
input,
|
||||
self.input_scale,
|
||||
scale_upper_bound=self.scale_upper_bound,
|
||||
scalar=True,
|
||||
)
|
||||
|
||||
output = torch._scaled_mm(
|
||||
qinput,
|
||||
self.qweight.t(),
|
||||
out_dtype=self.dtype,
|
||||
@ -275,11 +458,16 @@ class Fp8Linear(torch.nn.Module):
|
||||
scale_b=self.scale,
|
||||
bias=self.bias,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple) and len(output) == 2:
|
||||
output = output[0]
|
||||
|
||||
return output
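Note: as a reference for the block-wise path above (1x128 activation groups, 128x128 weight blocks, per the quoted DeepSeek-V3 description), a minimal dequantize-then-matmul sketch; it is illustrative only, is not the `w8a8_block_fp8_matmul` kernel, and assumes the dimensions divide evenly by the block size.

import torch

def reference_block_fp8_matmul(q_a, a_scale, q_w, w_scale, block_size=(128, 128), out_dtype=torch.bfloat16):
    # q_a: (M, K) fp8 activations, a_scale: (M, K // 128) per-token-group scales.
    # q_w: (N, K) fp8 weights, w_scale: (N // 128, K // 128) per-block scales.
    bn, bk = block_size
    a = q_a.to(torch.float32) * a_scale.repeat_interleave(bk, dim=1)
    w = q_w.to(torch.float32) * w_scale.repeat_interleave(bn, dim=0).repeat_interleave(bk, dim=1)
    return (a @ w.t()).to(out_dtype)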
|
||||
|
||||
|
||||
def _load_scalar_or_matrix_scale(weights: Weights, prefix: str, shape: torch.Size):
|
||||
scale = weights.get_tensor(prefix, to_dtype=False)
|
||||
|
||||
if scale.numel() > 1:
|
||||
scale = weights.get_sharded(prefix, dim=0, to_dtype=False)
|
||||
return scale.reshape(-1).expand(shape[0])
|
||||
|
@ -1,14 +1,15 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
|
||||
QuantLinear = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQWeight(Weight):
|
||||
qweight: torch.Tensor
|
||||
@ -30,13 +31,8 @@ class GPTQWeight(Weight):
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
if self.use_awq_kernel:
|
||||
if SYSTEM == "rocm":
|
||||
raise NotImplementedError(
|
||||
"AWQ GEMM kernel can't be used on ROCm systems, please use `--quantize gptq` instead "
|
||||
"to use Exllama/GPTQ kernels for AWQ inference."
|
||||
)
|
||||
try:
|
||||
from text_generation_server.layers.awq.quantize.qmodule import WQLinear
|
||||
from text_generation_server.layers.awq.quantize import WQLinear
|
||||
|
||||
return WQLinear(
|
||||
w_bit=self.bits,
|
||||
@ -50,18 +46,7 @@ class GPTQWeight(Weight):
|
||||
raise NotImplementedError(
|
||||
"You do not seem to have awq installed, either install it (cd server && make install-awq), or try using GPTQ `---quantize gptq` a conversion AWQ->GPTQ will happen on the fly"
|
||||
)
|
||||
elif self.use_exllama:
|
||||
try:
|
||||
from text_generation_server.layers.gptq import ExllamaQuantLinear
|
||||
except ImportError:
|
||||
raise NotImplementedError(
|
||||
"Exllama gptq kernels are not installed. Install them `cd server/exllama_kernels && python setup.py install && cd ../exllamav2_kernels && python setup.py install`"
|
||||
)
|
||||
|
||||
return ExllamaQuantLinear(self, bias)
|
||||
else:
|
||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
||||
|
||||
return QuantLinear(
|
||||
self.qweight,
|
||||
self.qzeros,
|
||||
@ -118,23 +103,6 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
else:
|
||||
g_idx = None
|
||||
|
||||
from text_generation_server.layers.gptq import (
|
||||
HAS_EXLLAMA,
|
||||
CAN_EXLLAMA,
|
||||
GPTQWeight,
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
if CAN_EXLLAMA:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
||||
)
|
||||
use_exllama = False
|
||||
else:
|
||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||
|
||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
|
||||
@ -298,6 +266,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
self._get_gptq_params(weights)
|
||||
|
||||
use_exllama = True
|
||||
desc_act = self.desc_act
|
||||
if self.bits != 4:
|
||||
use_exllama = False
|
||||
|
||||
@ -321,7 +290,8 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
if g_idx is not None:
|
||||
if (
|
||||
not torch.equal(
|
||||
g_idx.cpu(),
|
||||
# Remove g_idx[0] to adapt the check with TP>1.
|
||||
(g_idx - g_idx[0]).cpu(),
|
||||
torch.tensor(
|
||||
[i // self.groupsize for i in range(g_idx.shape[0])],
|
||||
dtype=torch.int32,
|
||||
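Note: the `g_idx` check in the hunk above decides whether act-order is effectively absent; subtracting `g_idx[0]` makes the test shard-invariant under TP>1. A small illustrative example of the invariant being tested:

import torch

groupsize, in_features = 128, 512
g_idx = torch.tensor([i // groupsize for i in range(in_features)], dtype=torch.int32)
sharded_g_idx = g_idx + 2  # e.g. this shard starts at group 2 when the dimension is split
# Still "trivial" after removing the offset, so the Exllama kernels remain usable.
assert torch.equal(sharded_g_idx - sharded_g_idx[0], g_idx)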
@ -332,34 +302,22 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
# The Exllama implementation does not support row tensor parallelism with act-order, as
# it would require reordering input activations that are split across several GPUs.
use_exllama = False
|
||||
desc_act = True
|
||||
|
||||
from text_generation_server.layers.gptq import (
|
||||
CAN_EXLLAMA,
|
||||
HAS_EXLLAMA,
|
||||
GPTQWeight,
|
||||
)
|
||||
|
||||
if use_exllama:
|
||||
if not HAS_EXLLAMA:
|
||||
if CAN_EXLLAMA:
|
||||
log_once(
|
||||
logger.warning,
|
||||
"Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True",
|
||||
)
|
||||
use_exllama = False
|
||||
else:
|
||||
log_once(logger.info, f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||
|
||||
if use_exllama and self.groupsize != -1:
|
||||
if not desc_act and self.groupsize != -1:
|
||||
qzeros = weights.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = weights.get_sharded(f"{prefix}.scales", dim=0)
|
||||
if g_idx is not None:
|
||||
# qzeros, scales sharded, and g_idx must be adjusted accordingly
|
||||
g_idx = g_idx - g_idx[0]
|
||||
else:
|
||||
qzeros = weights.get_tensor(f"{prefix}.qzeros")
|
||||
scales = weights.get_tensor(f"{prefix}.scales")
|
||||
|
||||
if use_exllama and g_idx is not None:
|
||||
g_idx = g_idx - g_idx[0]
|
||||
|
||||
if self.quantize == "gptq" and self.quant_method == "awq":
|
||||
log_once(
|
||||
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
|
||||
@ -392,7 +350,7 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
)
|
||||
|
||||
def _get_gptq_params(self, weights: Weights):
|
||||
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
||||
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
|
||||
self.bits = weights.get_tensor("gptq_bits").item()
|
||||
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
self.desc_act = False
|
||||
@ -400,41 +358,10 @@ class GPTQWeightsLoader(WeightsLoader):
|
||||
# before the `gptq_sym` setting tensor was added.
|
||||
self.sym = (
|
||||
weights.get_tensor("gptq_sym").item()
|
||||
if weights._has_tensor("gptq_sym")
|
||||
if weights.has_tensor("gptq_sym")
|
||||
else False
|
||||
)
|
||||
self.quant_method = "gptq"
|
||||
|
||||
|
||||
# Needs to be at the end because circular import.
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
except Exception:
|
||||
major = 1
|
||||
|
||||
HAS_EXLLAMA = False
|
||||
CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
|
||||
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
|
||||
if os.getenv("DISABLE_EXLLAMA") == "True":
|
||||
HAS_EXLLAMA = False
|
||||
elif CAN_EXLLAMA:
|
||||
try:
|
||||
if V2:
|
||||
from text_generation_server.layers.gptq.exllamav2 import (
|
||||
QuantLinear as ExllamaQuantLinear, # noqa: F401
|
||||
create_exllama_buffers, # noqa: F401
|
||||
set_device, # noqa: F401
|
||||
)
|
||||
|
||||
HAS_EXLLAMA = "2"
|
||||
else:
|
||||
from text_generation_server.layers.gptq.exllama import (
|
||||
Ex4bitLinear as ExllamaQuantLinear, # noqa: F401
|
||||
create_exllama_buffers, # noqa: F401
|
||||
set_device, # noqa: F401
|
||||
)
|
||||
|
||||
HAS_EXLLAMA = "1"
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
@ -1,261 +0,0 @@
|
||||
# https://github.com/fpgaminer/GPTQ-triton
|
||||
"""
|
||||
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
import math
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
import triton
|
||||
|
||||
|
||||
class Autotuner(triton.KernelInterface):
|
||||
def __init__(
|
||||
self,
|
||||
fn,
|
||||
arg_names,
|
||||
configs,
|
||||
key,
|
||||
reset_to_zero,
|
||||
prune_configs_by: Dict = None,
|
||||
nearest_power_of_two: bool = False,
|
||||
):
|
||||
"""
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predict the running time with different configs; returns the running time
'top_k': number of configs to bench
|
||||
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs: List[Config] as its input and returns the pruned configs.
'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
|
||||
"""
|
||||
if not configs:
|
||||
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
|
||||
else:
|
||||
self.configs = configs
|
||||
self.key_idx = [arg_names.index(k) for k in key]
|
||||
self.nearest_power_of_two = nearest_power_of_two
|
||||
self.cache = {}
|
||||
# hook to reset all required tensor to zeros before relaunching a kernel
|
||||
self.hook = lambda args: 0
|
||||
if reset_to_zero is not None:
|
||||
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
|
||||
|
||||
def _hook(args):
|
||||
for i in self.reset_idx:
|
||||
args[i].zero_()
|
||||
|
||||
self.hook = _hook
|
||||
self.arg_names = arg_names
|
||||
# prune configs
|
||||
if prune_configs_by:
|
||||
perf_model, top_k = (
|
||||
prune_configs_by["perf_model"],
|
||||
prune_configs_by["top_k"],
|
||||
)
|
||||
if "early_config_prune" in prune_configs_by:
|
||||
early_config_prune = prune_configs_by["early_config_prune"]
|
||||
else:
|
||||
perf_model, top_k, early_config_prune = None, None, None
|
||||
self.perf_model, self.configs_top_k = perf_model, top_k
|
||||
self.early_config_prune = early_config_prune
|
||||
self.fn = fn
|
||||
|
||||
def _bench(self, *args, config, **meta):
|
||||
# check for conflicts, i.e. meta-parameters both provided
|
||||
# as kwargs and by the autotuner
|
||||
conflicts = meta.keys() & config.kwargs.keys()
|
||||
if conflicts:
|
||||
raise ValueError(
|
||||
f"Conflicting meta-parameters: {', '.join(conflicts)}."
|
||||
" Make sure that you don't re-define auto-tuned symbols."
|
||||
)
|
||||
# augment meta-parameters with tunable ones
|
||||
current = dict(meta, **config.kwargs)
|
||||
|
||||
def kernel_call():
|
||||
if config.pre_hook:
|
||||
config.pre_hook(self.nargs)
|
||||
self.hook(args)
|
||||
self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
**current,
|
||||
)
|
||||
|
||||
try:
|
||||
# In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
|
||||
return triton.testing.do_bench(
|
||||
kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40
|
||||
)
|
||||
except triton.OutOfResources:
|
||||
return [float("inf"), float("inf"), float("inf")]
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
if len(self.configs) > 1:
|
||||
key = tuple(args[i] for i in self.key_idx)
|
||||
|
||||
# This reduces the amount of autotuning by rounding the keys to the nearest power of two
|
||||
# In my testing this gives decent results, and greatly reduces the amount of tuning required
|
||||
if self.nearest_power_of_two:
|
||||
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
|
||||
|
||||
if key not in self.cache:
|
||||
# prune configs
|
||||
pruned_configs = self.prune_configs(kwargs)
|
||||
bench_start = time.time()
|
||||
timings = {
|
||||
config: self._bench(*args, config=config, **kwargs)
|
||||
for config in pruned_configs
|
||||
}
|
||||
bench_end = time.time()
|
||||
self.bench_time = bench_end - bench_start
|
||||
self.cache[key] = builtins.min(timings, key=timings.get)
|
||||
self.hook(args)
|
||||
self.configs_timings = timings
|
||||
config = self.cache[key]
|
||||
else:
|
||||
config = self.configs[0]
|
||||
self.best_config = config
|
||||
if config.pre_hook is not None:
|
||||
config.pre_hook(self.nargs)
|
||||
return self.fn.run(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
|
||||
def prune_configs(self, kwargs):
|
||||
pruned_configs = self.configs
|
||||
if self.early_config_prune:
|
||||
pruned_configs = self.early_config_prune(self.configs, self.nargs)
|
||||
if self.perf_model:
|
||||
top_k = self.configs_top_k
|
||||
if isinstance(top_k, float) and top_k <= 1.0:
|
||||
top_k = int(len(self.configs) * top_k)
|
||||
if len(pruned_configs) > top_k:
|
||||
est_timing = {
|
||||
config: self.perf_model(
|
||||
**self.nargs,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
)
|
||||
for config in pruned_configs
|
||||
}
|
||||
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[
|
||||
:top_k
|
||||
]
|
||||
return pruned_configs
|
||||
|
||||
def warmup(self, *args, **kwargs):
|
||||
self.nargs = dict(zip(self.arg_names, args))
|
||||
for config in self.prune_configs(kwargs):
|
||||
self.fn.warmup(
|
||||
*args,
|
||||
num_warps=config.num_warps,
|
||||
num_stages=config.num_stages,
|
||||
**kwargs,
|
||||
**config.kwargs,
|
||||
)
|
||||
self.nargs = None
|
||||
|
||||
|
||||
def autotune(
|
||||
configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False
|
||||
):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
],
|
||||
key=['x_size'] # the two above configs will be evaluated anytime
|
||||
# the value of x_size changes
|
||||
)
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
BLOCK_SIZE = META['BLOCK_SIZE']
|
||||
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
|
||||
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
|
||||
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
|
||||
:type configs: list[triton.Config]
|
||||
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
|
||||
:type key: list[str]
|
||||
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
|
||||
'perf_model': performance model used to predict the running time with different configs; returns the running time
'top_k': number of configs to bench
|
||||
'early_config_prune'(optional): a function used to do early pruning (e.g., on num_stages). It takes configs: List[Config] as its input and returns the pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
|
||||
:type reset_to_zero: list[str]
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
return Autotuner(
|
||||
fn,
|
||||
fn.arg_names,
|
||||
configs,
|
||||
key,
|
||||
reset_to_zero,
|
||||
prune_configs_by,
|
||||
nearest_power_of_two,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def matmul248_kernel_config_pruner(configs, nargs):
|
||||
"""
|
||||
The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
|
||||
"""
|
||||
m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
|
||||
n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
|
||||
k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
|
||||
|
||||
used = set()
|
||||
for config in configs:
|
||||
block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
|
||||
block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
|
||||
block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
|
||||
group_size_m = config.kwargs["GROUP_SIZE_M"]
|
||||
|
||||
if (
|
||||
block_size_m,
|
||||
block_size_n,
|
||||
block_size_k,
|
||||
group_size_m,
|
||||
config.num_stages,
|
||||
config.num_warps,
|
||||
) in used:
|
||||
continue
|
||||
|
||||
used.add(
|
||||
(
|
||||
block_size_m,
|
||||
block_size_n,
|
||||
block_size_k,
|
||||
group_size_m,
|
||||
config.num_stages,
|
||||
config.num_warps,
|
||||
)
|
||||
)
|
||||
yield triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": block_size_m,
|
||||
"BLOCK_SIZE_N": block_size_n,
|
||||
"BLOCK_SIZE_K": block_size_k,
|
||||
"GROUP_SIZE_M": group_size_m,
|
||||
},
|
||||
num_stages=config.num_stages,
|
||||
num_warps=config.num_warps,
|
||||
)
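Note: the autotuner above caches tuned configs per kernel key and, with `nearest_power_of_two=True`, snaps each key dimension to a nearby power of two before the cache lookup. A small illustration of that rounding:

import math

def round_key(key):
    # Mirrors `2 ** int(math.log2(x) + 0.5)` from Autotuner.run: the exponent is rounded,
    # so nearby problem sizes share one tuned config instead of retuning every shape.
    return tuple(2 ** int(math.log2(x) + 0.5) for x in key)

# round_key((117, 4096, 11008)) -> (128, 4096, 8192)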
|
@ -1,134 +0,0 @@
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
import torch
|
||||
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device="meta")
|
||||
|
||||
|
||||
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
|
||||
"""Construct Q4Matrix, return handle"""
|
||||
return make_q4(
|
||||
qweight, qzeros, scales, g_idx if g_idx is not None else none_tensor, device
|
||||
)
|
||||
|
||||
|
||||
def ext_q4_matmul(x, q4, q4_width):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
outshape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
|
||||
|
||||
q4_matmul(x, q4, output)
|
||||
|
||||
return output.view(outshape)
|
||||
|
||||
|
||||
MAX_DQ = 1
|
||||
MAX_INNER = 1
|
||||
ACT_ORDER = False
|
||||
DEVICE = None
|
||||
|
||||
TEMP_STATE = None
|
||||
TEMP_DQ = None
|
||||
|
||||
|
||||
def set_device(device):
|
||||
global DEVICE
|
||||
DEVICE = device
|
||||
|
||||
|
||||
def create_exllama_buffers(max_total_tokens: int):
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
|
||||
|
||||
assert DEVICE is not None, "call set_device first"
|
||||
|
||||
if not ACT_ORDER:
|
||||
max_total_tokens = 1
|
||||
|
||||
# This temp_state buffer is required to reorder X in the act-order case.
|
||||
temp_state = torch.zeros(
|
||||
(max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE
|
||||
)
|
||||
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
|
||||
|
||||
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
|
||||
prepare_buffers(DEVICE, temp_state, temp_dq)
|
||||
|
||||
matmul_recons_thd = 8
|
||||
matmul_fused_remap = False
|
||||
matmul_no_half2 = False
|
||||
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
|
||||
|
||||
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
|
||||
|
||||
|
||||
class Ex4bitLinear(torch.nn.Module):
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
|
||||
def __init__(self, weight: GPTQWeight, bias):
|
||||
super().__init__()
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
|
||||
assert weight.bits == 4
|
||||
|
||||
self.device = weight.qweight.device
|
||||
self.qweight = weight.qweight
|
||||
self.qzeros = weight.qzeros
|
||||
self.scales = weight.scales
|
||||
self.g_idx = weight.g_idx.cpu() if weight.g_idx is not None else None
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
if self.g_idx is not None and (
|
||||
(self.g_idx == 0).all()
|
||||
or torch.equal(
|
||||
weight.g_idx.cpu(),
|
||||
torch.tensor(
|
||||
[i // weight.groupsize for i in range(weight.g_idx.shape[0])],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
):
|
||||
self.empty_g_idx = True
|
||||
self.g_idx = None
|
||||
|
||||
assert self.device.type == "cuda"
|
||||
assert self.device.index is not None
|
||||
|
||||
self.q4 = ext_make_q4(
|
||||
self.qweight, self.qzeros, self.scales, self.g_idx, self.device.index
|
||||
)
|
||||
|
||||
self.height = weight.qweight.shape[0] * 8
|
||||
self.width = weight.qweight.shape[1]
|
||||
|
||||
# Infer groupsize from height of qzeros
|
||||
self.groupsize = None
|
||||
if self.qzeros.shape[0] > 1:
|
||||
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
|
||||
|
||||
if self.groupsize is not None:
|
||||
assert weight.groupsize == self.groupsize
|
||||
|
||||
# Handle act-order matrix
|
||||
if self.g_idx is not None:
|
||||
if self.groupsize is None:
|
||||
raise ValueError("Found group index but no groupsize. What do?")
|
||||
self.act_order = True
|
||||
else:
|
||||
self.act_order = False
|
||||
|
||||
DEVICE = self.qweight.device
|
||||
|
||||
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
|
||||
|
||||
if self.act_order:
|
||||
MAX_INNER = max(MAX_INNER, self.height, self.width)
|
||||
|
||||
ACT_ORDER = True
|
||||
|
||||
def forward(self, x):
|
||||
out = ext_q4_matmul(x, self.q4, self.width)
|
||||
|
||||
if self.bias is not None:
|
||||
out.add_(self.bias)
|
||||
return out
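Note: the group size inference above relies on the 4-bit packing (eight values per int32 row of `qweight`, one `qzeros` row per group). Illustrative numbers only:

in_features, group_size = 4096, 128
qweight_rows = in_features // 8           # eight 4-bit values packed per int32
qzeros_rows = in_features // group_size   # one zero-point row per group
assert (qweight_rows * 8) // qzeros_rows == group_size  # matches weight.groupsize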
|
@ -1,267 +0,0 @@
|
||||
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from text_generation_server.layers.exl2 import Exl2Weight
|
||||
from text_generation_server.layers.gptq import GPTQWeight
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
try:
|
||||
from exllamav2.ext import exllamav2_ext
|
||||
|
||||
make_q_matrix = exllamav2_ext.make_q_matrix
|
||||
gemm_half_q_half = exllamav2_ext.gemm_half_q_half
|
||||
except ImportError:
|
||||
log_master(logger.warning, "exllamav2_kernels not installed.")
|
||||
raise
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device="meta")
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ExtraTensors:
|
||||
"""Additional generated quantizer tensors."""
|
||||
|
||||
q_group_map: Optional[torch.Tensor] = None
|
||||
q_invperm: Optional[torch.Tensor] = None
|
||||
q_perm: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
output_shape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device)
|
||||
gemm_half_q_half(x, q_handle, output, force_cuda)
|
||||
return output.view(output_shape)
|
||||
|
||||
|
||||
def make_group_map(q_groups: torch.Tensor, num_qrows: int):
|
||||
gr = q_groups.tolist()
|
||||
group_map = []
|
||||
num_groups = len(gr) // 2
|
||||
|
||||
for i in range(num_groups):
|
||||
bits = gr[i * 2]
|
||||
if i < num_groups - 1:
|
||||
qrows = gr[i * 2 + 3] - gr[i * 2 + 1]
|
||||
else:
|
||||
qrows = num_qrows - gr[i * 2 + 1]
|
||||
rows = qrows * 32 // bits
|
||||
for j in range(rows):
|
||||
group_map += [i]
|
||||
group_map += [rows - j]
|
||||
|
||||
return torch.tensor(group_map, dtype=torch.short, device=q_groups.device)
|
||||
|
||||
|
||||
# Create Q matrix
|
||||
|
||||
|
||||
def ext_make_q_matrix(
|
||||
w: Exl2Weight | GPTQWeight,
|
||||
extra: _ExtraTensors,
|
||||
temp_dq,
|
||||
key: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Create Q matrix
|
||||
"""
|
||||
# max_dq_size = 512*(1024**2)
|
||||
# max_dq_rows = max_dq_size // out_features[0]
|
||||
max_dq_rows = 0
|
||||
|
||||
# EXL2
|
||||
if isinstance(w, Exl2Weight):
|
||||
extra.q_group_map = make_group_map(w.q_groups, w.q_weight.shape[0])
|
||||
extra.q_perm = torch.argsort(w.q_invperm).short()
|
||||
|
||||
return make_q_matrix(
|
||||
w.q_weight,
|
||||
extra.q_perm,
|
||||
w.q_invperm,
|
||||
w.q_scale,
|
||||
w.q_scale_max,
|
||||
w.q_groups,
|
||||
extra.q_group_map,
|
||||
none_tensor, # zeros
|
||||
none_tensor, # scales
|
||||
none_tensor, # g_idx
|
||||
none_tensor, # bias
|
||||
temp_dq,
|
||||
max_dq_rows,
|
||||
)
|
||||
# GPTQ
|
||||
elif isinstance(w, GPTQWeight):
|
||||
if w.scales.dtype == torch.float:
|
||||
w.scales = w.scales.half()
|
||||
|
||||
# GPTQ with g_idx (act_order)
|
||||
if w.g_idx is not None and not (w.g_idx == 0).all().item():
|
||||
extra.q_perm = torch.empty(
|
||||
(w.qweight.shape[0] * 8,),
|
||||
dtype=torch.short,
|
||||
device=w.qweight.device,
|
||||
)
|
||||
extra.q_invperm = torch.empty_like(extra.q_perm)
|
||||
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||
return make_q_matrix(
|
||||
w.qweight,
|
||||
extra.q_perm,
|
||||
extra.q_invperm,
|
||||
none_tensor, # q_scale
|
||||
none_tensor, # q_scale_max
|
||||
none_tensor, # q_groups
|
||||
none_tensor, # q_group_map
|
||||
w.qzeros,
|
||||
w.scales,
|
||||
w.g_idx.cpu(),
|
||||
none_tensor, # bias
|
||||
temp_dq,
|
||||
max_dq_rows,
|
||||
)
|
||||
# GPTQ without g_idx
|
||||
else:
|
||||
return make_q_matrix(
|
||||
w.qweight,
|
||||
none_tensor, # q_perm
|
||||
none_tensor, # q_invperm
|
||||
none_tensor, # q_scale
|
||||
none_tensor, # q_scale_max
|
||||
none_tensor, # q_groups
|
||||
none_tensor, # q_group_map
|
||||
w.qzeros,
|
||||
w.scales,
|
||||
none_tensor, # g_idx
|
||||
none_tensor, # bias
|
||||
temp_dq,
|
||||
max_dq_rows,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Cannot create handle")
|
||||
|
||||
DEVICE = None
|
||||
LAYERS = []
|
||||
|
||||
|
||||
def set_device(device):
|
||||
global DEVICE
|
||||
DEVICE = device
|
||||
|
||||
|
||||
def create_exllama_buffers(max_total_tokens: int):
|
||||
global LAYERS, DEVICE
|
||||
|
||||
# No need to initialize scratch space if there are no layers
|
||||
# that use ExLLamav2.
|
||||
if len(LAYERS) == 0:
|
||||
return
|
||||
|
||||
# Find the size of the scratch space.
|
||||
scratch_bytes = max(
|
||||
layer.scratch_space_fixed(max_input_len=max_total_tokens, max_batch_size=1)
|
||||
for layer in LAYERS
|
||||
)
|
||||
temp_dq = ExLlamaV2DeviceTensors(DEVICE, scratch_bytes)
|
||||
|
||||
for layer in LAYERS:
|
||||
layer.post_init(temp_dq)
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "exllamav2"
|
||||
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: Exl2Weight | GPTQWeight,
|
||||
bias: torch.Tensor,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.q_handle = None
|
||||
self.q_tensors = weight
|
||||
self.extra_tensors = _ExtraTensors()
|
||||
|
||||
if isinstance(weight, Exl2Weight):
|
||||
self.infeatures = weight.q_invperm.shape[0]
|
||||
self.outfeatures = weight.q_weight.shape[1]
|
||||
elif isinstance(weight, GPTQWeight):
|
||||
if weight.bits != 4:
|
||||
raise ValueError(
|
||||
f"Exllamav2 kernel supports only bits=4, requested bits={weight.bits}. Something is wrong in the model initialization."
|
||||
)
|
||||
|
||||
self.infeatures = weight.qweight.shape[0] // weight.bits * 32
|
||||
self.outfeatures = weight.qweight.shape[1]
|
||||
|
||||
self.padding = -self.outfeatures % 32
|
||||
self.outfeatures = self.outfeatures + self.padding
|
||||
|
||||
self.device = weight.device
|
||||
self.bias = bias if bias is not None else None
|
||||
|
||||
global LAYERS
|
||||
LAYERS.append(self)
|
||||
|
||||
def post_init(self, temp_dq):
|
||||
device = self.q_tensors.device
|
||||
assert device.type == "cuda"
|
||||
assert device.index is not None
|
||||
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||
|
||||
# We NEED to keep a pointer on Python side, otherwise the garbage collector will mess with us,
|
||||
# and `Memory access fault by GPU node-2` will EAT you.
|
||||
self.temp_dq = temp_dq
|
||||
self.q_handle = ext_make_q_matrix(self.q_tensors, self.extra_tensors, temp_dq)
|
||||
|
||||
def forward(self, x, force_cuda=False):
|
||||
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
||||
|
||||
if self.bias is not None:
|
||||
output.add_(self.bias)
|
||||
return output
|
||||
|
||||
def temp_dq_size(self):
|
||||
return self.infeatures * self.outfeatures * 2 + 128
|
||||
|
||||
def temp_fwd_size(self, max_input_len, max_batch_size):
|
||||
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
|
||||
|
||||
def scratch_space_fixed(self, max_input_len, max_batch_size):
|
||||
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
||||
|
||||
|
||||
class ExLlamaV2DeviceTensors:
|
||||
|
||||
device_idx: int
|
||||
scratch_bytes: int
|
||||
scratch_idx: int
|
||||
scratch: torch.tensor = None
|
||||
|
||||
def __init__(self, device, scratch_bytes):
|
||||
self.device = device
|
||||
self.scratch_bytes = scratch_bytes
|
||||
|
||||
def prepare(self):
|
||||
self.scratch = torch.empty(
|
||||
(self.scratch_bytes // 2,), dtype=torch.half, device=self.device
|
||||
)
|
||||
|
||||
def get_scratch_slice(self, size_bytes):
|
||||
|
||||
if self.scratch is None:
|
||||
self.prepare()
|
||||
|
||||
size_bytes = ((size_bytes + 127) // 128) * 128
|
||||
size_half = size_bytes // 2
|
||||
scratch_slice = self.scratch.narrow(0, 0, size_half)
|
||||
return scratch_slice
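Note: the scratch buffer created in `create_exllama_buffers` is sized from the per-layer formulas above; the concrete numbers below are illustrative only (hypothetical 4096x11008 layer, max_input_len=1024, batch size 1):

infeatures, outfeatures = 4096, 11008
temp_dq_bytes = infeatures * outfeatures * 2 + 128       # fp16 dequantized weight buffer
temp_fwd_bytes = outfeatures * 1024 * 1 * 4 + 128        # forward intermediate buffer
scratch_bytes = temp_dq_bytes + temp_fwd_bytes           # scratch_space_fixed(...)
# get_scratch_slice then rounds each request up to a multiple of 128 bytes:
rounded = ((temp_dq_bytes + 127) // 128) * 128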
|
backends/gaudi/server/text_generation_server/layers/gptq/ipex.py (new file, 125 lines)
@ -0,0 +1,125 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
super().__init__()
|
||||
self.register_buffer("qweight", qweight)
|
||||
self.register_buffer("qzeros", qzeros)
|
||||
self.register_buffer("scales", scales)
|
||||
self.register_buffer("g_idx", g_idx)
|
||||
if bias is not None:
|
||||
self.register_buffer("bias", bias)
|
||||
else:
|
||||
self.bias = None
|
||||
if bits not in [4]:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
self.bits = bits
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.groupsize = groupsize
|
||||
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.infeatures = qweight.shape[0] * 32 // bits
|
||||
self.woq_linear = (
|
||||
ipex.llm.quantization.IPEXWeightOnlyQuantizedLinear.from_weight(
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.qzeros,
|
||||
self.infeatures,
|
||||
self.outfeatures,
|
||||
bias=self.bias,
|
||||
group_size=self.groupsize,
|
||||
g_idx=g_idx,
|
||||
quant_method=ipex.llm.quantization.QuantMethod.GPTQ_GEMM,
|
||||
dtype=ipex.llm.quantization.QuantDtype.INT4,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||
if bits not in [4]:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
||||
qzeros = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
scales = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
||||
)
|
||||
g_idx = torch.tensor(
|
||||
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
)
|
||||
if bias:
|
||||
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||
else:
|
||||
bias = None
|
||||
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
||||
/ self.scales[self.g_idx[idx]]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros(
|
||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [4]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros(
|
||||
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [4]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 4 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
out = self.woq_linear(x.reshape(-1, x.shape[-1]))
|
||||
return out.reshape(out_shape)
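Note: `pack` above stores eight 4-bit values per int32, least-significant nibble first (and packs `qzeros` the same way along columns). A minimal sketch of the row packing, illustrative only:

import numpy as np

def pack_int4_rows(intweight: np.ndarray) -> np.ndarray:
    # intweight: (rows, cols) with values in [0, 15]; rows must be a multiple of 8.
    assert intweight.shape[0] % 8 == 0
    packed = np.zeros((intweight.shape[0] // 8, intweight.shape[1]), dtype=np.uint32)
    for row in range(packed.shape[0]):
        for j in range(8):
            packed[row] |= intweight[row * 8 + j].astype(np.uint32) << np.uint32(4 * j)
    return packed  # the module above then stores this as int32 via torch.from_numpy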
|
@ -1,359 +0,0 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.cuda.amp import custom_fwd
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from . import custom_autotune
|
||||
|
||||
|
||||
# code based https://github.com/fpgaminer/GPTQ-triton
|
||||
@custom_autotune.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
),
|
||||
],
|
||||
key=["M", "N", "K"],
|
||||
nearest_power_of_two=True,
|
||||
prune_configs_by={
|
||||
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
|
||||
"perf_model": None,
|
||||
"top_k": None,
|
||||
},
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_248_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
scales_ptr,
|
||||
zeros_ptr,
|
||||
g_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
maxq,
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_scales,
|
||||
stride_zeros,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Compute the matrix multiplication C = A x B.
|
||||
A is of shape (M, K) float16
|
||||
B is of shape (K//8, N) int32
|
||||
C is of shape (M, N) float16
|
||||
scales is of shape (G, N) float16
|
||||
zeros is of shape (G, N) float16
|
||||
g_ptr is of shape (K) int32
|
||||
"""
|
||||
infearure_per_bits = 32 // bits
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (
|
||||
offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
|
||||
) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
a_mask = offs_am[:, None] < M
|
||||
# b_ptrs is set up such that it repeats elements along the K axis 8 times
|
||||
b_ptrs = b_ptr + (
|
||||
(offs_k[:, None] // infearure_per_bits) * stride_bk
|
||||
+ offs_bn[None, :] * stride_bn
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
|
||||
g_ptrs = g_ptr + offs_k
|
||||
# shifter is used to extract the N bits of each element in the 32-bit word from B
|
||||
scales_ptrs = scales_ptr + offs_bn[None, :]
|
||||
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
|
||||
|
||||
shifter = (offs_k % infearure_per_bits) * bits
|
||||
zeros_shifter = (offs_bn % infearure_per_bits) * bits
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
for k in range(0, num_pid_k):
|
||||
g_idx = tl.load(g_ptrs)
|
||||
|
||||
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
|
||||
scales = tl.load(
|
||||
scales_ptrs + g_idx[:, None] * stride_scales
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
zeros = tl.load(
|
||||
zeros_ptrs + g_idx[:, None] * stride_zeros
|
||||
) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1) & maxq  # add 1 and mask to avoid overflow
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
||||
# Now we need to unpack b (which is N-bit values) into 32-bit values
|
||||
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
|
||||
b = (b - zeros) * scales # Scale and shift
|
||||
|
||||
accumulator += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K
|
||||
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
|
||||
g_ptrs += BLOCK_SIZE_K
|
||||
|
||||
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
|
||||
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
with torch.cuda.device(input.device):
|
||||
output = torch.empty(
|
||||
(input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
|
||||
)
|
||||
|
||||
def grid(META):
|
||||
return (
|
||||
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
|
||||
matmul_248_kernel[grid](
|
||||
input,
|
||||
qweight,
|
||||
output,
|
||||
scales,
|
||||
qzeros,
|
||||
g_idx,
|
||||
input.shape[0],
|
||||
qweight.shape[1],
|
||||
input.shape[1],
|
||||
bits,
|
||||
maxq,
|
||||
input.stride(0),
|
||||
input.stride(1),
|
||||
qweight.stride(0),
|
||||
qweight.stride(1),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
scales.stride(0),
|
||||
qzeros.stride(0),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class QuantLinearFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float16)
|
||||
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
|
||||
output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
|
||||
return output
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
|
||||
super().__init__()
|
||||
self.register_buffer("qweight", qweight)
|
||||
self.register_buffer("qzeros", qzeros)
|
||||
self.register_buffer("scales", scales)
|
||||
self.register_buffer("g_idx", g_idx)
|
||||
if bias is not None:
|
||||
self.register_buffer("bias", bias)
|
||||
else:
|
||||
self.bias = None
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
self.bits = bits
|
||||
self.maxq = 2**self.bits - 1
|
||||
self.groupsize = groupsize
|
||||
|
||||
self.outfeatures = qweight.shape[1]
|
||||
self.infeatures = qweight.shape[0] * 32 // bits
|
||||
|
||||
@classmethod
|
||||
def new(cls, bits, groupsize, infeatures, outfeatures, bias):
|
||||
if bits not in [2, 4, 8]:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)
|
||||
qzeros = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures // 32 * bits),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
scales = torch.zeros(
|
||||
(math.ceil(infeatures / groupsize), outfeatures), dtype=torch.float16
|
||||
)
|
||||
g_idx = torch.tensor(
|
||||
[i // groupsize for i in range(infeatures)], dtype=torch.int32
|
||||
)
|
||||
if bias:
|
||||
bias = torch.zeros((outfeatures), dtype=torch.float16)
|
||||
else:
|
||||
bias = None
|
||||
return cls(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
|
||||
|
||||
def pack(self, linear, scales, zeros, g_idx=None):
|
||||
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
|
||||
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone().half()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone().half()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
intweight.append(
|
||||
torch.round(
|
||||
(linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
|
||||
/ self.scales[self.g_idx[idx]]
|
||||
).to(torch.int)[:, None]
|
||||
)
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros(
|
||||
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros(
|
||||
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
|
||||
)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,4,8 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
out = QuantLinearFunction.apply(
|
||||
x.reshape(-1, x.shape[-1]),
|
||||
self.qweight,
|
||||
self.scales,
|
||||
self.qzeros,
|
||||
self.g_idx,
|
||||
self.bits,
|
||||
self.maxq,
|
||||
)
|
||||
out = out + self.bias if self.bias is not None else out
|
||||
return out.reshape(out_shape)
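Note: as a plain-PyTorch reference for the Triton kernel above (same packing and the same `(zeros + 1) & maxq` adjustment), useful for correctness checks; illustrative only, assuming `bits` divides 32:

import torch

def dequant_matmul_reference(a, qweight, scales, qzeros, g_idx, bits=4):
    maxq = (1 << bits) - 1
    shifts = torch.arange(0, 32, bits, device=qweight.device, dtype=torch.int32)
    g = g_idx.long()
    # Unpack weights: (K // (32 // bits), N) int32 -> (K, N) integers in [0, maxq].
    w = ((qweight.unsqueeze(1) >> shifts.view(1, -1, 1)) & maxq).reshape(-1, qweight.shape[1])
    # Unpack zero points: (G, N // (32 // bits)) int32 -> (G, N), plus the kernel's +1 adjustment.
    z = ((qzeros.unsqueeze(2) >> shifts.view(1, 1, -1)) & maxq).reshape(qzeros.shape[0], -1)
    z = (z + 1) & maxq
    # Per-input-channel group lookup, then scale/shift and matmul.
    w = (w - z[g]).float() * scales[g].float()
    return (a.float() @ w).to(a.dtype)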
|
@ -12,7 +12,7 @@ from huggingface_hub import HfApi
|
||||
from accelerate import init_empty_weights
|
||||
from text_generation_server.utils import initialize_torch_distributed, Weights
|
||||
from text_generation_server.utils.hub import weight_files
|
||||
from text_generation_server.layers.gptq.quant_linear import QuantLinear
|
||||
from text_generation_server.layers.gptq import QuantLinear
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from text_generation_server.layers.gptq.utils import torch_snr_error
|
||||
@ -956,15 +956,24 @@ def quantize(
|
||||
|
||||
pack(model, quantizers, bits, groupsize)
|
||||
from safetensors.torch import save_file
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict = {k: v.cpu().contiguous() for k, v in state_dict.items()}
|
||||
|
||||
max_shard_size = "10GB"
|
||||
shards, index = shard_checkpoint(
|
||||
state_dict, max_shard_size=max_shard_size, weights_name="model.safetensors"
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict,
|
||||
filename_pattern="model.safetensors",
|
||||
max_shard_size=max_shard_size,
|
||||
)
|
||||
index = None
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
shards = state_dict_split.filename_to_tensors
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
for shard_file, shard in shards.items():
|
||||
save_file(
|
||||
|
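Note: the hunk above is truncated by the diff context; as a hedged, condensed sketch of the new sharding flow it introduces (`split_torch_state_dict_into_shards` maps shard filenames to lists of tensor names, so each shard dict is rebuilt from `state_dict` before `save_file`):

import os
from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

def save_sharded(state_dict, output_dir, max_shard_size="10GB"):
    split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern="model.safetensors", max_shard_size=max_shard_size
    )
    os.makedirs(output_dir, exist_ok=True)
    for filename, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name] for name in tensor_names}
        save_file(shard, os.path.join(output_dir, filename), metadata={"format": "pt"})
    if split.is_sharded:
        # The index written alongside the shards, as built in the hunk above.
        return {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
    return None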
@ -1,9 +1,6 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from accelerate import init_empty_weights
|
||||
from text_generation_server.utils.import_utils import (
|
||||
SYSTEM,
|
||||
)
|
||||
|
||||
|
||||
# Monkey patching
|
||||
@ -33,48 +30,8 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
|
||||
torch.nn.LayerNorm.load = load_layer_norm
|
||||
torch.nn.LayerNorm.load_no_bias = load_layer_norm_no_bias
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
import dropout_layer_norm
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
return super(FastLayerNorm, self).forward(hidden_states), residual
|
||||
else:
|
||||
(
|
||||
normed_hidden_states,
|
||||
residual,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
self.bias,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
|
||||
return normed_hidden_states, residual
|
||||
|
||||
elif SYSTEM == "rocm":
|
||||
from vllm._C import ops
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
@ -82,21 +39,6 @@ elif SYSTEM == "rocm":
|
||||
|
||||
return super().forward(hidden_states), residual
|
||||
|
||||
elif SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
out = ipex.llm.functional.add_layer_norm(
|
||||
residual,
|
||||
hidden_states,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.eps,
|
||||
residual is not None,
|
||||
)
|
||||
return out, residual if residual is not None else hidden_states
|
||||
|
||||
|
||||
class FastRMSNorm(nn.Module):
|
||||
def __init__(self, weight: torch.Tensor, eps: float):
|
||||
@ -111,74 +53,15 @@ class FastRMSNorm(nn.Module):
|
||||
return cls(weight, eps)
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if SYSTEM == "ipex":
|
||||
out = ipex.llm.functional.add_rms_norm(
|
||||
residual,
|
||||
hidden_states,
|
||||
self.weight,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
residual is not None,
|
||||
)
|
||||
return out, residual if residual is not None else hidden_states
|
||||
elif hidden_states.shape[-1] > 8192:
|
||||
from vllm_hpu_extension.kernels import rms_norm
|
||||
|
||||
orig_shape = hidden_states.shape
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states, residual
|
||||
elif SYSTEM == "cuda":
|
||||
# faster post attention rms norm
|
||||
(
|
||||
normed_hidden_states,
|
||||
res,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
return normed_hidden_states, res
|
||||
elif SYSTEM == "rocm":
|
||||
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
out = torch.empty_like(hidden_states)
|
||||
ops.rms_norm(
|
||||
out,
|
||||
hidden_states,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return out, residual
|
||||
residual += hidden_states.view(residual.shape)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
)
|
||||
residual = hidden_states
|
||||
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
|
||||
if len(orig_shape) == 2:
|
||||
residual = residual.unsqueeze(0)
|
||||
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
|
||||
return x.view(orig_shape), residual.view(orig_shape)
|
||||
|
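Note: the HPU `FastRMSNorm.forward` above either falls back to plain PyTorch for hidden sizes above 8192 or calls the fused `rms_norm` kernel from `vllm_hpu_extension` on 3-D inputs. A device-agnostic reference of the math both paths compute (a sketch for clarity, not the fused kernel):

import torch


def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Compute the norm in float32 for stability, then cast back to the input dtype.
    orig_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(orig_dtype)


x = torch.randn(4, 128, dtype=torch.bfloat16)
w = torch.ones(128, dtype=torch.bfloat16)
out = rms_norm_reference(x, w, eps=1e-6)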
@ -1,21 +1,5 @@
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from torch.nn import functional as F
|
||||
import os
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
ROCM_USE_SKINNY_GEMM = os.getenv("ROCM_USE_SKINNY_GEMM", "True").lower() in (
|
||||
"true",
|
||||
"1",
|
||||
)
|
||||
|
||||
if ROCM_USE_SKINNY_GEMM:
|
||||
try:
|
||||
from vllm import _custom_C
|
||||
except Exception as e:
|
||||
raise ImportError(
|
||||
f"Could not load `vllm._custom_C` for ROCm skinny gemm. Full error: {e}"
|
||||
)
|
||||
|
||||
|
||||
class FastLinear(torch.nn.Module):
|
||||
@ -44,83 +28,11 @@ class FastLinear(torch.nn.Module):
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
|
||||
|
||||
class FastLinearROCm(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
weight,
|
||||
bias,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(weight)
|
||||
if bias is not None:
|
||||
self.bias = torch.nn.Parameter(bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
self.cu_count = torch.cuda.get_device_properties(
|
||||
device="cuda"
|
||||
).multi_processor_count
|
||||
self.use_skinny_gemm = (
|
||||
ROCM_USE_SKINNY_GEMM
|
||||
and "gfx1" not in torch.cuda.get_device_properties("cuda").gcnArchName
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix: str, weights, bias: bool):
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
if bias:
|
||||
bias = weights.get_tensor(f"{prefix}.bias")
|
||||
else:
|
||||
bias = None
|
||||
return cls(weight, bias)
|
||||
|
||||
def forward(self, inp: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
bias = self.bias
|
||||
|
||||
if (
|
||||
self.use_skinny_gemm
|
||||
and inp.dtype == torch.float16
|
||||
and inp.shape[-1] % 8 == 0
|
||||
):
|
||||
batched = False
|
||||
inp_shape = inp.shape
|
||||
|
||||
if inp.dim() == 3:
|
||||
inp = inp.view(-1, inp_shape[-1])
|
||||
batched = True
|
||||
|
||||
m, n, k = weight.shape[0], inp_shape[0], inp_shape[1]
|
||||
if m > 8 and n <= 4:
|
||||
out = torch.empty(
|
||||
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
|
||||
)
|
||||
_custom_C.wvSpltK(weight, inp, out, n, self.cu_count)
|
||||
elif m % 4 == 0 and n == 1 and k <= 8192:
|
||||
out = torch.empty(
|
||||
inp_shape[0], weight.shape[0], dtype=inp.dtype, device=weight.device
|
||||
)
|
||||
_custom_C.LLMM1(weight, inp, out, 4)
|
||||
else:
|
||||
out = F.linear(inp, weight)
|
||||
|
||||
if batched:
|
||||
out.view(*inp_shape[:-1], out.shape[-1])
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out
|
||||
return F.linear(inp, self.weight, self.bias)
|
||||
|
||||
|
||||
def get_linear(weight, bias):
|
||||
# Weights that are loaded through methods that are not
|
||||
# quantization-aware are still bare tensors. We may want
|
||||
# to change this in the future.
|
||||
if isinstance(weight, torch.Tensor):
|
||||
if SYSTEM == "rocm":
|
||||
return FastLinearROCm(weight, bias)
|
||||
else:
|
||||
return FastLinear(weight, bias)
|
||||
|
||||
return weight.get_linear(bias)
|
||||
|
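Note: after the cleanup, `get_linear` has only two paths: bare tensors become a `FastLinear`, and quantized weight objects build their own linear via `weight.get_linear(bias)`. A self-contained stand-in for the unquantized path (names here are illustrative, not the repo's):

from typing import Optional

import torch
from torch.nn import functional as F


class FastLinearSketch(torch.nn.Module):
    # Minimal stand-in: a bias-optional matmul over a preloaded weight tensor,
    # which is all the unquantized path does after loading.
    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        self.bias = torch.nn.Parameter(bias, requires_grad=False) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


layer = FastLinearSketch(torch.randn(32, 64), torch.zeros(32))
y = layer(torch.randn(8, 64))  # -> shape (8, 32)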
@ -2,19 +2,15 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from loguru import logger
|
||||
from text_generation_server.layers.fp8 import fp8_quantize
|
||||
from text_generation_server.layers.marlin.gptq import _check_valid_shape
|
||||
from text_generation_server.layers.marlin.util import (
|
||||
_check_marlin_kernels,
|
||||
permute_scales,
|
||||
)
|
||||
from text_generation_server.utils.log import log_once
|
||||
|
||||
try:
|
||||
import marlin_kernels
|
||||
except ImportError:
|
||||
marlin_kernels = None
|
||||
|
||||
quantization = None
|
||||
|
||||
|
||||
MARLIN_TILE_SIZE = 16
|
||||
@ -34,9 +30,7 @@ class GPTQMarlinFP8Linear(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
|
||||
log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
|
||||
assert quantization is not None
|
||||
|
||||
scales = scales.unsqueeze(0)
|
||||
if scales.shape[1] == 1:
|
||||
@ -62,14 +56,21 @@ class GPTQMarlinFP8Linear(nn.Module):
|
||||
return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
|
||||
|
||||
@classmethod
|
||||
def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
|
||||
def from_fp8(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
scale: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
**kwargs,
|
||||
):
|
||||
return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
A_flat = A.view(-1, A.shape[-1])
|
||||
C = marlin_kernels.fp8_marlin_gemm(
|
||||
C = quantization.fp8_marlin_gemm(
|
||||
A_flat,
|
||||
self.qweight,
|
||||
self.scales,
|
||||
@ -131,7 +132,7 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
|
||||
qweight = pack_fp8_as_int32(weight.t())
|
||||
|
||||
perm = torch.empty(0, dtype=torch.int, device=qweight.device)
|
||||
repacked = marlin_kernels.gptq_marlin_repack(
|
||||
repacked = quantization.gptq_marlin_repack(
|
||||
qweight, perm, in_features, out_features, 8
|
||||
)
|
||||
|
||||
|
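Note: `repack_fp8_for_marlin` above first packs four FP8 bytes into each int32 (`pack_fp8_as_int32`) so the weight can be fed to `gptq_marlin_repack` as 8-bit data. The following is a hypothetical illustration of that packing step only; the byte order is an assumption and the repo's own `pack_fp8_as_int32` is authoritative. It needs a torch build with float8 support.

import torch


def pack_fp8_as_int32_sketch(weight: torch.Tensor) -> torch.Tensor:
    # Pack groups of 4 float8_e4m3fn values (1 byte each) along dim 0 into one int32.
    assert weight.dtype == torch.float8_e4m3fn
    assert weight.shape[0] % 4 == 0
    as_bytes = weight.view(torch.uint8).to(torch.int32).reshape(-1, 4, weight.shape[1])
    packed = torch.zeros(weight.shape[0] // 4, weight.shape[1], dtype=torch.int32)
    for i in range(4):
        packed |= as_bytes[:, i] << (8 * i)
    return packed


w = torch.randn(8, 16).to(torch.float8_e4m3fn)
print(pack_fp8_as_int32_sketch(w).shape)  # torch.Size([2, 16])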
@ -11,14 +11,12 @@ from text_generation_server.layers.marlin.util import (
|
||||
permute_scales,
|
||||
unpack_cols,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
try:
|
||||
import marlin_kernels
|
||||
except ImportError:
|
||||
marlin_kernels = None
|
||||
|
||||
quantization = None
|
||||
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
@ -35,17 +33,7 @@ MARLIN_TILE_SIZE = 16
|
||||
def can_use_gptq_marlin(
|
||||
*, bits: int, groupsize: int, quant_method: str, quantize: str, sym: bool
|
||||
) -> bool:
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
and marlin_kernels is not None
|
||||
and has_sm_8_0
|
||||
and quantize in {"awq", "gptq"}
|
||||
and quant_method in {"awq", "gptq"}
|
||||
and bits in GPTQ_MARLIN_BITS
|
||||
and groupsize in GPTQ_MARLIN_GROUP_SIZES
|
||||
# We only support asymmetric quantization for AWQ.
|
||||
and (sym or quant_method == "awq")
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class GPTQMarlinWeightsLoader(WeightsLoader):
|
||||
@ -231,7 +219,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
|
||||
)
|
||||
|
||||
def _get_gptq_params(self, weights: Weights):
|
||||
if weights._has_tensor("gptq_bits") and weights._has_tensor("gptq_groupsize"):
|
||||
if weights.has_tensor("gptq_bits") and weights.has_tensor("gptq_groupsize"):
|
||||
self.bits = weights.get_tensor("gptq_bits").item()
|
||||
self.groupsize = weights.get_tensor("gptq_groupsize").item()
|
||||
self.desc_act = False
|
||||
@ -239,7 +227,7 @@ class GPTQMarlinWeightsLoader(WeightsLoader):
|
||||
# before the `gptq_sym` setting tensor was added.
|
||||
self.sym = (
|
||||
weights.get_tensor("gptq_sym").item()
|
||||
if weights._has_tensor("gptq_sym")
|
||||
if weights.has_tensor("gptq_sym")
|
||||
else False
|
||||
)
|
||||
self.quant_method = "gptq"
|
||||
@ -261,7 +249,7 @@ class GPTQMarlinWeight(Weight):
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.qweight.dtype == torch.int32
|
||||
assert self.scales.dtype == torch.float16
|
||||
assert self.scales.dtype in (torch.float16, torch.bfloat16)
|
||||
assert self.g_idx.dtype == torch.int32
|
||||
assert self.perm.dtype == torch.int32
|
||||
|
||||
@ -287,7 +275,7 @@ def repack_gptq_for_marlin(
|
||||
) -> GPTQMarlinWeight:
|
||||
"""Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
if bits not in GPTQ_MARLIN_BITS:
|
||||
supported_bits = ", ".join(str(b) for b in GPTQ_MARLIN_BITS)
|
||||
@ -300,7 +288,7 @@ def repack_gptq_for_marlin(
|
||||
raise RuntimeError(
|
||||
f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
|
||||
)
|
||||
if not (sym or quant_method == "awq"):
|
||||
if not (sym or quant_method == "awq" or quant_method == "compressed-tensors"):
|
||||
raise RuntimeError(
|
||||
"Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
|
||||
)
|
||||
@ -330,7 +318,7 @@ def repack_gptq_for_marlin(
|
||||
g_idx = torch.empty(0, dtype=torch.int, device=qweight.device)
|
||||
|
||||
if quant_method == "awq":
|
||||
repacked = marlin_kernels.awq_marlin_repack(
|
||||
repacked = quantization.awq_marlin_repack(
|
||||
qweight, in_features, out_features, bits
|
||||
)
|
||||
if qzeros is not None:
|
||||
@ -342,7 +330,7 @@ def repack_gptq_for_marlin(
|
||||
)
|
||||
|
||||
else:
|
||||
repacked = marlin_kernels.gptq_marlin_repack(
|
||||
repacked = quantization.gptq_marlin_repack(
|
||||
qweight, perm, in_features, out_features, bits
|
||||
)
|
||||
|
||||
@ -379,13 +367,26 @@ class GPTQMarlinLinear(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
in_features = weight.qweight.shape[0] * MARLIN_TILE_SIZE
|
||||
out_features = weight.scales.shape[1]
|
||||
_check_valid_shape(in_features=in_features, out_features=out_features)
|
||||
|
||||
self.bits = weight.bits
|
||||
if weight.bits not in (4, 8):
|
||||
raise ValueError("GPTQMarlinLinear only supports 4 and 8-bit quantization")
|
||||
|
||||
if weight.qzeros.numel() > 0:
|
||||
if weight.bits == 4:
|
||||
self.quant_type = quantization.scalar_types.uint4
|
||||
else:
|
||||
self.quant_type = quantization.scalar_types.uint8
|
||||
else:
|
||||
if weight.bits == 4:
|
||||
self.quant_type = quantization.scalar_types.uint4b8
|
||||
else:
|
||||
self.quant_type = quantization.scalar_types.uint8b128
|
||||
|
||||
self.is_full_k = weight.is_full_k
|
||||
|
||||
self.qweight = weight.qweight
|
||||
@ -403,10 +404,10 @@ class GPTQMarlinLinear(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
A_flat = A.view(-1, A.shape[-1])
|
||||
C = marlin_kernels.gptq_marlin_gemm(
|
||||
C = quantization.gptq_marlin_gemm(
|
||||
A_flat,
|
||||
self.qweight,
|
||||
self.scales,
|
||||
@ -414,7 +415,7 @@ class GPTQMarlinLinear(nn.Module):
|
||||
self.g_idx,
|
||||
self.perm,
|
||||
self.workspace,
|
||||
self.bits,
|
||||
self.quant_type,
|
||||
A_flat.shape[0],
|
||||
self.scales.shape[1],
|
||||
A_flat.shape[1],
|
||||
|
@ -3,13 +3,11 @@ from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.layers.marlin.util import _check_marlin_kernels
|
||||
from text_generation_server.utils.weights import Weight, Weights, WeightsLoader
|
||||
|
||||
try:
|
||||
import marlin_kernels
|
||||
except ImportError:
|
||||
marlin_kernels = None
|
||||
quantization = None
|
||||
|
||||
|
||||
class MarlinWeightsLoader(WeightsLoader):
|
||||
@ -34,7 +32,9 @@ class MarlinWeightsLoader(WeightsLoader):
|
||||
|
||||
B_meta = weights.get_tensor(f"{prefix}.B_meta")
|
||||
s = weights.get_tensor(f"{prefix}.s")
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
weight = GPTQMarlin24Weight(
|
||||
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = weights.get_tensor(f"{prefix}.B")
|
||||
@ -65,7 +65,9 @@ class MarlinWeightsLoader(WeightsLoader):
|
||||
f"{prefix}.s", dim=1, block_sizes=block_sizes
|
||||
)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
weight = GPTQMarlin24Weight(
|
||||
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
|
||||
)
|
||||
else:
|
||||
B = weights.get_packed_sharded(
|
||||
f"{prefix}.B", dim=1, block_sizes=block_sizes
|
||||
@ -96,7 +98,9 @@ class MarlinWeightsLoader(WeightsLoader):
|
||||
[weights.get_sharded(f"{p}.s", dim=1) for p in prefixes], dim=1
|
||||
)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
weight = GPTQMarlin24Weight(
|
||||
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = torch.cat(
|
||||
@ -132,7 +136,9 @@ class MarlinWeightsLoader(WeightsLoader):
|
||||
else:
|
||||
s = weights.get_sharded(f"{prefix}.s", dim=0)
|
||||
|
||||
weight = GPTQMarlin24Weight(B=B, B_meta=B_meta, s=s, bits=self.bits)
|
||||
weight = GPTQMarlin24Weight(
|
||||
weight_packed=B, meta=B_meta, scale_packed=s, bits=self.bits
|
||||
)
|
||||
else:
|
||||
try:
|
||||
B = weights.get_sharded(f"{prefix}.B", dim=0)
|
||||
@ -179,7 +185,7 @@ class MarlinLinear(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE
|
||||
out_features = weight.s.shape[1]
|
||||
@ -208,9 +214,9 @@ class MarlinLinear(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
C = marlin_kernels.marlin_gemm(
|
||||
C = quantization.marlin_gemm(
|
||||
A.view(-1, A.shape[-1]),
|
||||
self.B,
|
||||
self.s,
|
||||
@ -247,15 +253,15 @@ class GPTQMarlin24Weight:
|
||||
bits: quantized weight size.
|
||||
"""
|
||||
|
||||
B: torch.Tensor
|
||||
B_meta: torch.Tensor
|
||||
s: torch.Tensor
|
||||
weight_packed: torch.Tensor
|
||||
meta: torch.Tensor
|
||||
scale_packed: torch.Tensor
|
||||
bits: int
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.B.dtype == torch.int32
|
||||
assert self.B_meta.dtype == torch.int16
|
||||
assert self.s.dtype == torch.float16
|
||||
assert self.weight_packed.dtype == torch.int32
|
||||
assert self.meta.dtype == torch.int16
|
||||
assert self.scale_packed.dtype == torch.float16
|
||||
|
||||
def get_linear(self, bias: torch.Tensor):
|
||||
return GPTQMarlin24Linear(
|
||||
@ -269,7 +275,7 @@ class GPTQMarlin24Linear(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
_check_marlin_kernels()
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
if weight.bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
|
||||
supported_bits = ", ".join(
|
||||
@ -279,9 +285,13 @@ class GPTQMarlin24Linear(nn.Module):
|
||||
f"{weight.bits}-bit GPTQ Sparse 2:4 Marlin is not supported, must be one of: {supported_bits}"
|
||||
)
|
||||
|
||||
in_features = weight.B.shape[0] * MARLIN_TILE_SIZE * 2
|
||||
out_features = weight.s.shape[1]
|
||||
groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
|
||||
in_features = weight.weight_packed.shape[0] * MARLIN_TILE_SIZE * 2
|
||||
out_features = weight.scale_packed.shape[1]
|
||||
groupsize = (
|
||||
-1
|
||||
if weight.scale_packed.shape[0] == 1
|
||||
else in_features // weight.scale_packed.shape[0]
|
||||
)
|
||||
|
||||
if groupsize not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
||||
supported_sizes = ", ".join(
|
||||
@ -291,8 +301,11 @@ class GPTQMarlin24Linear(nn.Module):
|
||||
f"Group size {groupsize} is not supported, must be one of: {supported_sizes}"
|
||||
)
|
||||
|
||||
self.bits = weight.bits
|
||||
weights_per_int32 = 32 // self.bits
|
||||
if weight.bits == 4:
|
||||
self.quant_type = quantization.scalar_types.uint4b8
|
||||
else:
|
||||
self.quant_type = quantization.scalar_types.uint8b128
|
||||
weights_per_int32 = 32 // weight.bits
|
||||
|
||||
assert (
|
||||
out_features % GPTQ_MARLIN_24_MIN_THREAD_N == 0
|
||||
@ -309,9 +322,9 @@ class GPTQMarlin24Linear(nn.Module):
|
||||
f"Number of input features ({in_features}) not divisable by group size ({groupsize})"
|
||||
)
|
||||
|
||||
self.B = weight.B
|
||||
self.B_meta = weight.B_meta
|
||||
self.s = weight.s
|
||||
self.weight_packed = weight.weight_packed
|
||||
self.meta = weight.meta
|
||||
self.scale_packed = weight.scale_packed
|
||||
if bias is not None:
|
||||
self.bias = bias
|
||||
else:
|
||||
@ -320,25 +333,25 @@ class GPTQMarlin24Linear(nn.Module):
|
||||
self.workspace = torch.zeros(
|
||||
(out_features // GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
dtype=torch.int,
|
||||
device=weight.B.device,
|
||||
device=weight.weight_packed.device,
|
||||
)
|
||||
|
||||
def forward(self, A: torch.Tensor) -> torch.Tensor:
|
||||
assert marlin_kernels is not None
|
||||
assert quantization is not None
|
||||
|
||||
C = marlin_kernels.gptq_marlin_24_gemm(
|
||||
C = quantization.gptq_marlin_24_gemm(
|
||||
A.view(-1, A.shape[-1]),
|
||||
self.B,
|
||||
self.B_meta,
|
||||
self.s,
|
||||
self.weight_packed,
|
||||
self.meta,
|
||||
self.scale_packed,
|
||||
self.workspace,
|
||||
self.bits,
|
||||
self.quant_type,
|
||||
A.shape[0],
|
||||
self.s.shape[1],
|
||||
self.scale_packed.shape[1],
|
||||
A.shape[1],
|
||||
)
|
||||
|
||||
C = C.reshape(A.shape[:-1] + (self.s.shape[1],))
|
||||
C = C.reshape(A.shape[:-1] + (self.scale_packed.shape[1],))
|
||||
|
||||
if self.bias is not None:
|
||||
C += self.bias
|
||||
|
@ -3,12 +3,9 @@ from typing import List, Tuple
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
try:
|
||||
import marlin_kernels
|
||||
except ImportError:
|
||||
marlin_kernels = None
|
||||
|
||||
quantization = None
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
@ -18,12 +15,11 @@ except Exception:
|
||||
|
||||
|
||||
def _check_marlin_kernels():
|
||||
if not (SYSTEM == "cuda" and has_sm_8_0):
|
||||
raise NotImplementedError(
|
||||
"Using quantized Marlin models requires a GPU with CUDA capability 8.0 or later."
|
||||
)
|
||||
|
||||
if marlin_kernels is None:
|
||||
if quantization is None:
|
||||
raise NotImplementedError(
|
||||
"marlin is not installed, install it with: pip install server/marlin"
|
||||
)
|
||||
|
@ -10,13 +10,8 @@ from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
from text_generation_server.layers.marlin import GPTQMarlinWeightsLoader
|
||||
from text_generation_server.layers.moe.gptq_marlin import (
|
||||
GPTQMarlinSparseMoELayer,
|
||||
can_use_marlin_moe_gemm,
|
||||
)
|
||||
from text_generation_server.layers.moe.unquantized import UnquantizedSparseMoELayer
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.moe.fp8 import FP8SparseMoELayer
|
||||
from text_generation_server.utils.log import log_once
|
||||
from text_generation_server.utils.weights import (
|
||||
DefaultWeightsLoader,
|
||||
@ -24,12 +19,7 @@ from text_generation_server.utils.weights import (
|
||||
UnquantizedWeight,
|
||||
)
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
from .fused_moe_rocm import grouped_topk
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
elif SYSTEM != "ipex":
|
||||
from moe_kernels.fused_moe import fused_topk, grouped_topk
|
||||
|
||||
from .fused_moe_ipex import fused_topk, grouped_topk
|
||||
|
||||
# NOTE: we are using a protocol here, because multiple inheritance is not nice.
|
||||
# We need `Module`, and `Module` -> some abstract class -> some concrete
|
||||
@ -52,6 +42,8 @@ class MoELayer(Protocol):
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
hidden_act: str = "silu",
|
||||
scoring_func: Optional[str] = None,
|
||||
e_score_correction_bias: Optional[float] = None,
|
||||
): ...
|
||||
|
||||
def forward(
|
||||
@ -81,9 +73,14 @@ class DenseMoELayer(nn.Module):
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
hidden_act: str = "silu",
|
||||
scoring_func: Optional[str] = None,
|
||||
e_score_correction_bias: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert scoring_func is None, "scoring func is not handled"
|
||||
assert e_score_correction_bias is None, "scoring correction bias is not handled"
|
||||
|
||||
log_once(
|
||||
logger.info,
|
||||
"No fused layers are available for this model type, using (slower) dense MoE layer",
|
||||
@ -199,22 +196,27 @@ class SparseMoELayer(nn.Module):
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
scoring_func: Optional[str] = "softmax",
|
||||
e_score_correction_bias: Optional[float] = None,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if (
|
||||
isinstance(weights.loader, DefaultWeightsLoader)
|
||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||
) or isinstance(weights.loader, HybridFP8UnquantLoader):
|
||||
if (
|
||||
isinstance(weights.loader, HybridFP8UnquantLoader)
|
||||
and weights.loader.to_fp8
|
||||
):
|
||||
cls = FP8SparseMoELayer
|
||||
else:
|
||||
cls = UnquantizedSparseMoELayer
|
||||
elif isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym:
|
||||
cls = GPTQMarlinSparseMoELayer
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weights loader: {weights.loader}, sparse MoE is only supported for unquantized and GPTQ weights"
|
||||
f"Unsupported weights loader: {type(weights.loader)}, sparse MoE is only supported for unquantized, AWQ, and GPTQ weights"
|
||||
)
|
||||
|
||||
log_once(
|
||||
@ -230,6 +232,8 @@ class SparseMoELayer(nn.Module):
|
||||
topk=topk,
|
||||
topk_group=topk_group,
|
||||
weights=weights,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
gate_proj_name=gate_proj_name,
|
||||
up_proj_name=up_proj_name,
|
||||
down_proj_name=down_proj_name,
|
||||
@ -241,17 +245,6 @@ class SparseMoELayer(nn.Module):
|
||||
@staticmethod
|
||||
def is_supported(weights: Weights) -> bool:
|
||||
return (
|
||||
(
|
||||
isinstance(weights.loader, DefaultWeightsLoader)
|
||||
and isinstance(weights.loader.weight_class, UnquantizedWeight)
|
||||
)
|
||||
or isinstance(weights.loader, HybridFP8UnquantLoader)
|
||||
or (
|
||||
isinstance(weights.loader, GPTQMarlinWeightsLoader)
|
||||
and can_use_marlin_moe_gemm(
|
||||
quant_method=weights.loader.quant_method,
|
||||
quantize=weights.loader.quantize,
|
||||
sym=weights.loader.sym,
|
||||
)
|
||||
)
|
||||
)
|
||||
) or isinstance(weights.loader, HybridFP8UnquantLoader)
|
||||
|
backends/gaudi/server/text_generation_server/layers/moe/fp8.py (new file, 173 lines)
@ -0,0 +1,173 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from text_generation_server.layers.fp8 import (
|
||||
Fp8Weight,
|
||||
fp8_quantize,
|
||||
quant_dtype,
|
||||
normalize_e4m3fn_to_native_float8,
|
||||
)
|
||||
|
||||
try:
|
||||
from .unquantized import fused_moe
|
||||
except Exception:
|
||||
fused_moe = None
|
||||
|
||||
|
||||
class FP8SparseMoELayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_expert_group: Optional[int],
|
||||
n_experts: int,
|
||||
prefix: str,
|
||||
renormalize: bool,
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
scoring_func: Optional[str] = "softmax",
|
||||
e_score_correction_bias: Optional[float] = None,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
assert (n_expert_group is None) == (
|
||||
topk_group is None
|
||||
), "n_expert_group and topk_group must both be None or have some value"
|
||||
|
||||
self.n_expert_group = n_expert_group
|
||||
self.topk = topk
|
||||
self.topk_group = topk_group
|
||||
self.renormalize = renormalize
|
||||
self.weight_block_size = weights.weights_loader.weight_block_size
|
||||
self.scoring_func = scoring_func
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
|
||||
(
|
||||
self.gate_up_proj,
|
||||
self.gate_up_proj_weight_scale,
|
||||
self.gate_up_proj_input_scale,
|
||||
) = _load_expert_multi_weights_col(
|
||||
prefix=prefix,
|
||||
n_experts=n_experts,
|
||||
gate_proj_name=gate_proj_name,
|
||||
up_proj_name=up_proj_name,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
|
||||
_load_expert_weights_row(
|
||||
prefix=prefix,
|
||||
n_experts=n_experts,
|
||||
name=down_proj_name,
|
||||
weights=weights,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
return fused_moe(
|
||||
x,
|
||||
w1=self.gate_up_proj,
|
||||
w2=self.down_proj,
|
||||
gating_output=gating_output,
|
||||
topk=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=self.gate_up_proj_weight_scale,
|
||||
w2_scale=self.down_proj_weight_scale,
|
||||
a1_scale=self.gate_up_proj_input_scale,
|
||||
a2_scale=self.down_proj_input_scale,
|
||||
)
|
||||
|
||||
|
||||
def _load_expert_weights(
|
||||
get_weight_fn,
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
name: str,
|
||||
weights: Weights,
|
||||
) -> torch.Tensor:
|
||||
all_weight = None
|
||||
all_weight_scales = None
|
||||
max_input_scale = None
|
||||
|
||||
for i in range(n_experts):
|
||||
weight = get_weight_fn(prefix, i, name, weights)
|
||||
|
||||
assert isinstance(weight, Fp8Weight)
|
||||
|
||||
if all_weight is None:
|
||||
all_weight = torch.empty(
|
||||
(n_experts,) + weight.weight.shape,
|
||||
dtype=quant_dtype,
|
||||
device=weight.weight.device,
|
||||
)
|
||||
if all_weight_scales is None:
|
||||
all_weight_scales = torch.empty(
|
||||
(n_experts,) + weight.weight_scale.shape,
|
||||
dtype=torch.float32,
|
||||
device=weight.weight.device,
|
||||
)
|
||||
|
||||
if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
|
||||
all_weight[i], all_weight_scales[i], current_input_scale = (
|
||||
normalize_e4m3fn_to_native_float8(
|
||||
weight.weight, weight.weight_scale, weight.input_scale
|
||||
)
|
||||
)
|
||||
if current_input_scale is not None:
|
||||
if max_input_scale is None or current_input_scale > max_input_scale:
|
||||
max_input_scale = current_input_scale
|
||||
else:
|
||||
all_weight[i], all_weight_scales[i] = fp8_quantize(
|
||||
weight.weight, scalar=True
|
||||
)
|
||||
|
||||
assert all_weight is not None
|
||||
|
||||
return all_weight, all_weight_scales, max_input_scale
|
||||
|
||||
|
||||
def _load_expert_multi_weights_col(
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
gate_proj_name: str,
|
||||
up_proj_name: str,
|
||||
weights: Weights,
|
||||
) -> torch.Tensor:
|
||||
def get_weight_fn(prefix, i, name, weights):
|
||||
return weights.get_multi_weights_col(
|
||||
[f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
|
||||
)
|
||||
|
||||
return _load_expert_weights(
|
||||
get_weight_fn, prefix=prefix, n_experts=n_experts, name=None, weights=weights
|
||||
)
|
||||
|
||||
|
||||
def _load_expert_weights_row(
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
name: str,
|
||||
weights: Weights,
|
||||
) -> torch.Tensor:
|
||||
def get_weight_fn(prefix, i, name, weights):
|
||||
return weights.get_weights_row(f"{prefix}.{i}.{name}")
|
||||
|
||||
return _load_expert_weights(
|
||||
get_weight_fn, prefix=prefix, n_experts=n_experts, name=name, weights=weights
|
||||
)
|
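Note: `_load_expert_weights` above stacks each expert's already-quantized weight and scale into a leading expert dimension and keeps a single activation input scale as the max over experts. A plain-tensor sketch of that stacking pattern (real code works on `Fp8Weight` objects, not raw tensors):

import torch

n_experts, out_features, in_features = 4, 8, 16
expert_weights = [torch.randn(out_features, in_features) for _ in range(n_experts)]
expert_scales = [torch.rand(()) for _ in range(n_experts)]
expert_input_scales = [torch.rand(()) for _ in range(n_experts)]

all_weight = torch.empty((n_experts, out_features, in_features))
all_weight_scales = torch.empty((n_experts,), dtype=torch.float32)
max_input_scale = None
for i in range(n_experts):
    all_weight[i] = expert_weights[i]
    all_weight_scales[i] = expert_scales[i]
    # Track a single conservative input scale across experts.
    if max_input_scale is None or expert_input_scales[i] > max_input_scale:
        max_input_scale = expert_input_scales[i]

print(all_weight.shape, all_weight_scales.shape, float(max_input_scale))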
@ -16,10 +16,8 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
||||
# TODO: Remove the functions once moe_kernel are built for ROCM
|
||||
def grouped_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
@ -50,3 +48,18 @@ def grouped_topk(
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def fused_topk(
|
||||
hidden_states: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
topk_weights = torch.nn.functional.softmax(
|
||||
gating_output, dim=1, dtype=torch.float32
|
||||
)
|
||||
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
|
||||
if renormalize:
|
||||
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights, topk_ids
|
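Note: a quick shape-level example of the softmax + top-k routing that the `fused_topk` fallback above implements, inlined so it runs standalone: 5 tokens routed over 8 experts, keeping the top 2 experts per token.

import torch

gating_output = torch.randn(5, 8)

topk_weights = torch.nn.functional.softmax(gating_output, dim=1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, 2, dim=-1)
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)  # renormalize=True

print(topk_ids.shape, topk_weights.shape)  # torch.Size([5, 2]) torch.Size([5, 2])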
@ -1,215 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import Weights
|
||||
from text_generation_server.layers.marlin.gptq import (
|
||||
GPTQMarlinWeight,
|
||||
GPTQMarlinWeightsLoader,
|
||||
)
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
from moe_kernels.fused_marlin_moe import fused_marlin_moe
|
||||
else:
|
||||
fused_marlin_moe = None
|
||||
|
||||
|
||||
try:
|
||||
major, _minor = torch.cuda.get_device_capability()
|
||||
has_sm_8_0 = major >= 8
|
||||
except Exception:
|
||||
has_sm_8_0 = False
|
||||
|
||||
|
||||
def can_use_marlin_moe_gemm(
|
||||
*,
|
||||
quant_method: str,
|
||||
quantize: str,
|
||||
sym: bool,
|
||||
):
|
||||
return (
|
||||
SYSTEM == "cuda"
|
||||
and fused_marlin_moe is not None
|
||||
and has_sm_8_0
|
||||
and quantize == "gptq"
|
||||
and quant_method == "gptq"
|
||||
and sym
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQMarlinMoEWeight:
|
||||
qweight: torch.Tensor
|
||||
qzeros: torch.Tensor
|
||||
scales: torch.Tensor
|
||||
g_idx: torch.Tensor
|
||||
perm: torch.Tensor
|
||||
is_full_k: bool
|
||||
|
||||
|
||||
class GPTQMarlinSparseMoELayer(nn.Module):
|
||||
"""
|
||||
MoE layer that uses a fused GPTQ-Marlin kernel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_expert_group: Optional[int],
|
||||
n_experts: int,
|
||||
prefix: str,
|
||||
renormalize: bool,
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not (
|
||||
isinstance(weights.loader, GPTQMarlinWeightsLoader) and weights.loader.sym
|
||||
):
|
||||
raise ValueError(
|
||||
f"Unsupported weights loader: {weights.loader}, only GPTQMarlinWeightsLoader with symmetric quantization is supported"
|
||||
)
|
||||
|
||||
assert (n_expert_group is None) == (
|
||||
topk_group is None
|
||||
), "n_expert_group and topk_group must both be None or have some value"
|
||||
|
||||
self.n_expert_group = n_expert_group
|
||||
self.topk = topk
|
||||
self.topk_group = topk_group
|
||||
self.renormalize = renormalize
|
||||
|
||||
self.gate_up_proj = _load_expert_multi_weights_col(
|
||||
prefix=prefix,
|
||||
n_experts=n_experts,
|
||||
names=[gate_proj_name, up_proj_name],
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.down_proj = _load_expert_weights_row(
|
||||
prefix=prefix, n_experts=n_experts, name=down_proj_name, weights=weights
|
||||
)
|
||||
|
||||
self.bits = weights.loader.bits
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
w1=self.gate_up_proj.qweight,
|
||||
w2=self.down_proj.qweight,
|
||||
g_idx1=self.gate_up_proj.g_idx,
|
||||
g_idx2=self.down_proj.g_idx,
|
||||
perm1=self.gate_up_proj.perm,
|
||||
perm2=self.down_proj.perm,
|
||||
w1_scale=self.gate_up_proj.scales,
|
||||
w2_scale=self.down_proj.scales,
|
||||
is_full_k1=self.gate_up_proj.is_full_k,
|
||||
is_full_k2=self.down_proj.is_full_k,
|
||||
gating_output=gating_output,
|
||||
topk=self.topk,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
num_bits=self.bits,
|
||||
)
|
||||
|
||||
|
||||
def _load_expert_multi_weights_col(
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
names: List[str],
|
||||
weights: Weights,
|
||||
) -> GPTQMarlinMoEWeight:
|
||||
moe_weight = None
|
||||
for i in range(n_experts):
|
||||
weight = weights.get_multi_weights_col(
|
||||
[f"{prefix}.{i}.{name}" for name in names], 0
|
||||
)
|
||||
assert isinstance(weight, GPTQMarlinWeight)
|
||||
moe_weight = _pack_weight(
|
||||
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
|
||||
)
|
||||
assert moe_weight is not None
|
||||
return moe_weight
|
||||
|
||||
|
||||
def _load_expert_weights_row(
|
||||
*,
|
||||
prefix: str,
|
||||
n_experts: int,
|
||||
name: str,
|
||||
weights: Weights,
|
||||
) -> GPTQMarlinMoEWeight:
|
||||
moe_weight = None
|
||||
for i in range(n_experts):
|
||||
weight = weights.get_weights_row(
|
||||
f"{prefix}.{i}.{name}",
|
||||
)
|
||||
assert isinstance(weight, GPTQMarlinWeight)
|
||||
moe_weight = _pack_weight(
|
||||
n_experts=n_experts, expert=i, weight=weight, moe_weight=moe_weight
|
||||
)
|
||||
assert moe_weight is not None
|
||||
return moe_weight
|
||||
|
||||
|
||||
def _pack_weight(
|
||||
*,
|
||||
n_experts: int,
|
||||
expert: int,
|
||||
moe_weight: Optional[GPTQMarlinMoEWeight],
|
||||
weight: GPTQMarlinWeight,
|
||||
) -> GPTQMarlinMoEWeight:
|
||||
if moe_weight is None:
|
||||
qweight = torch.empty(
|
||||
(n_experts,) + weight.qweight.shape,
|
||||
dtype=weight.qweight.dtype,
|
||||
device=weight.qweight.device,
|
||||
)
|
||||
qzeros = torch.empty(
|
||||
(n_experts,) + weight.qzeros.shape,
|
||||
dtype=weight.qzeros.dtype,
|
||||
device=weight.qzeros.device,
|
||||
)
|
||||
scales = torch.empty(
|
||||
(n_experts,) + weight.scales.shape,
|
||||
dtype=weight.scales.dtype,
|
||||
device=weight.scales.device,
|
||||
)
|
||||
g_idx = torch.empty(
|
||||
(n_experts,) + weight.g_idx.shape,
|
||||
dtype=weight.g_idx.dtype,
|
||||
device=weight.g_idx.device,
|
||||
)
|
||||
perm = torch.empty(
|
||||
(n_experts,) + weight.perm.shape,
|
||||
dtype=weight.perm.dtype,
|
||||
device=weight.perm.device,
|
||||
)
|
||||
|
||||
moe_weight = GPTQMarlinMoEWeight(
|
||||
qweight=qweight,
|
||||
qzeros=qzeros,
|
||||
scales=scales,
|
||||
g_idx=g_idx,
|
||||
perm=perm,
|
||||
is_full_k=weight.is_full_k,
|
||||
)
|
||||
|
||||
moe_weight.qweight[expert] = weight.qweight
|
||||
moe_weight.qzeros[expert] = weight.qzeros
|
||||
moe_weight.scales[expert] = weight.scales
|
||||
moe_weight.g_idx[expert] = weight.g_idx
|
||||
moe_weight.perm[expert] = weight.perm
|
||||
|
||||
return moe_weight
|
@ -1,15 +1,11 @@
|
||||
from typing import Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import UnquantizedWeight, Weights
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
elif SYSTEM != "ipex":
|
||||
from moe_kernels.fused_moe import fused_moe
|
||||
moe_kernels = None
|
||||
|
||||
|
||||
class UnquantizedSparseMoELayer(nn.Module):
|
||||
@ -23,6 +19,8 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||
topk: int,
|
||||
topk_group: Optional[int],
|
||||
weights: Weights,
|
||||
scoring_func: Optional[str] = "softmax",
|
||||
e_score_correction_bias: Optional[float] = None,
|
||||
gate_proj_name: str = "gate_proj",
|
||||
up_proj_name: str = "up_proj",
|
||||
down_proj_name: str = "down_proj",
|
||||
@ -37,6 +35,9 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||
self.topk = topk
|
||||
self.topk_group = topk_group
|
||||
self.renormalize = renormalize
|
||||
self.weight_block_size = weights.weights_loader.weight_block_size
|
||||
self.scoring_func = scoring_func
|
||||
self.e_score_correction_bias = e_score_correction_bias
|
||||
|
||||
self.gate_up_proj = _load_expert_multi_weights_col(
|
||||
prefix=prefix,
|
||||
@ -54,17 +55,6 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
|
||||
if SYSTEM == "rocm":
|
||||
return fused_moe(
|
||||
x,
|
||||
self.gate_up_proj,
|
||||
self.down_proj,
|
||||
gating_output,
|
||||
self.topk,
|
||||
renormalize=self.renormalize,
|
||||
inplace=True,
|
||||
)
|
||||
|
||||
return fused_moe(
|
||||
x,
|
||||
w1=self.gate_up_proj,
|
||||
@ -76,6 +66,8 @@ class UnquantizedSparseMoELayer(nn.Module):
|
||||
use_grouped_topk=self.n_expert_group is not None,
|
||||
num_expert_group=self.n_expert_group,
|
||||
topk_group=self.topk_group,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
)
|
||||
|
||||
|
||||
@ -136,3 +128,110 @@ def _load_expert_weights_row(
|
||||
assert all_weight is not None
|
||||
|
||||
return all_weight
|
||||
|
||||
|
||||
def fused_moe(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
gating_output: torch.Tensor,
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- gating_output (torch.Tensor): The output of the gating operation
|
||||
(before softmax).
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
||||
- topk_group: Optional[int]: additional parameter for grouped_topk
|
||||
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
||||
note: Deepseekv2 model uses grouped_topk
|
||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
|
||||
activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w2.
|
||||
- a1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
a1.
|
||||
- a2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
a2.
|
||||
- block_shape: (Optional[List[int]]): Optional block size for block-wise
|
||||
quantization.
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
# Check constraints.
|
||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||
|
||||
if use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
from loguru import logger
|
||||
import inspect
|
||||
|
||||
logger.info(f"{inspect.signature(moe_kernels.grouped_topk)}")
|
||||
topk_weights, topk_ids = moe_kernels.grouped_topk(
|
||||
hidden_states,
|
||||
gating_output,
|
||||
topk,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
elif custom_routing_function is None:
|
||||
topk_weights, topk_ids = moe_kernels.fused_topk(
|
||||
hidden_states, gating_output, topk, renormalize
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states, gating_output, topk, renormalize
|
||||
)
|
||||
|
||||
return moe_kernels.fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
inplace=inplace,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
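Note: the wrapper above expects the usual stacked-expert layout: `w1` fuses gate_proj and up_proj per expert, `w2` is the per-expert down_proj. A shape-level reference of the same computation with top-2 routing, assuming the common silu gate/up split (the fused HPU kernels are the real implementation; this is only a readable sketch):

import torch

E, T, H, I = 8, 5, 64, 128          # experts, tokens, hidden size, intermediate size
hidden_states = torch.randn(T, H)
w1 = torch.randn(E, 2 * I, H)       # fused gate_proj + up_proj per expert
w2 = torch.randn(E, H, I)           # down_proj per expert
gating_output = torch.randn(T, E)

topk_w, topk_ids = torch.topk(torch.softmax(gating_output, dim=-1, dtype=torch.float32), 2, dim=-1)
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)

out = torch.zeros(T, H)
for t in range(T):
    for w, e in zip(topk_w[t], topk_ids[t]):
        gate_up = w1[e] @ hidden_states[t]                 # (2I,)
        gate, up = gate_up.chunk(2)
        out[t] += w * (w2[e] @ (torch.nn.functional.silu(gate) * up))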
@ -2,14 +2,10 @@ import os
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
import rotary_emb
|
||||
elif SYSTEM == "rocm":
|
||||
from vllm._C import ops
|
||||
elif SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
|
||||
def _create_inv_freq(dim, base, device):
|
||||
@ -30,7 +26,7 @@ def _get_rope_config(config):
|
||||
|
||||
|
||||
class PositionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, inv_freq, scaling_factor):
|
||||
def __init__(self, inv_freq, scaling_factor, max_position_embeddings):
|
||||
super().__init__()
|
||||
self.inv_freq = inv_freq
|
||||
self._seq_len_cached = 0
|
||||
@ -40,6 +36,9 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
self._sin_k_cached = None
|
||||
self.scaling_factor = scaling_factor
|
||||
self.dynamic_args = None
|
||||
self._update_cos_sin_cache(
|
||||
torch.float32, inv_freq.device, max_position_embeddings
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -48,34 +47,30 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
# Such controlflows may add some overhead.
|
||||
if SYSTEM == "cuda":
|
||||
rotary_dim = cos.shape[-1]
|
||||
q1 = query[..., :rotary_dim]
|
||||
q2 = query[..., rotary_dim : 2 * rotary_dim]
|
||||
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||
|
||||
k1 = key[..., :rotary_dim]
|
||||
k2 = key[..., rotary_dim : 2 * rotary_dim]
|
||||
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||
elif SYSTEM == "rocm":
|
||||
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
||||
|
||||
num_tokens = query.shape[0]
|
||||
head_size = query.shape[-1]
|
||||
# HPU RoPE kernel requires hidden dimension for cos and sin to be equal
|
||||
# to query hidden dimension, so the original tensors need to be
|
||||
# expanded
|
||||
# GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE
|
||||
# and expansion of cos/sin tensors via concatenation
|
||||
rope_mode = RotaryPosEmbeddingMode.BLOCKWISE
|
||||
cos = torch.cat((cos, cos), dim=-1)
|
||||
sin = torch.cat((sin, sin), dim=-1)
|
||||
rotary_dim = cos.shape[-1]
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
query_rot = query[..., :rotary_dim]
|
||||
query_pass = query[..., rotary_dim:]
|
||||
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
|
||||
|
||||
# Inplace operation, updating query and key.
|
||||
ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||
elif SYSTEM == "ipex":
|
||||
ipex.llm.functional.rotary_embedding(
|
||||
query, key, sin, cos, query.size(-1), True
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
)
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
key_rot = key[..., :rotary_dim]
|
||||
key_pass = key[..., rotary_dim:]
|
||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
|
||||
|
||||
@classmethod
|
||||
def static(cls, config, dim, base, device):
|
||||
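Note: the comments in the forward above explain why cos/sin are concatenated and why BLOCKWISE (GPT-NeoX style) mode is used on HPU. A pure-PyTorch reference of what that rotation computes, applied to a full head dimension for simplicity (the real forward only rotates the leading `rotary_dim` slice and concatenates the pass-through part back):

import torch


def neox_rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [tokens, heads, rotary_dim]; cos/sin: [tokens, rotary_dim // 2], as returned
    # by get_cos_sin. Blockwise rotation pairs dimension i with i + rotary_dim // 2.
    cos = torch.cat((cos, cos), dim=-1).unsqueeze(1)
    sin = torch.cat((sin, sin), dim=-1).unsqueeze(1)
    x1, x2 = x.chunk(2, dim=-1)
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin


q = torch.randn(5, 8, 64)
angles = torch.randn(5, 32)
q_rot = neox_rope_reference(q, torch.cos(angles), torch.sin(angles))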
@ -89,6 +84,14 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
|
||||
if rope_type == "linear":
|
||||
pass
|
||||
elif rope_type == "default":
|
||||
pass
|
||||
elif rope_type == "mrope":
|
||||
mrope_section = rope_scaling["mrope_section"]
|
||||
if mrope_section is not None:
|
||||
return RotaryPositionEmbeddingMultimodalSections(
|
||||
inv_freq, scaling_factor, mrope_section
|
||||
)
|
||||
elif rope_type == "dynamic":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
return DynamicPositionRotaryEmbedding(
|
||||
@ -109,7 +112,7 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
],
|
||||
)
|
||||
|
||||
return cls(inv_freq, scaling_factor)
|
||||
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||
|
||||
elif rope_type == "yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
@ -190,7 +193,7 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
raise NotImplementedError(
|
||||
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||
)
|
||||
return cls(inv_freq, scaling_factor)
|
||||
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||
|
||||
@classmethod
|
||||
def load(cls, config, prefix, weights):
|
||||
@ -236,7 +239,7 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
raise NotImplementedError(
|
||||
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
|
||||
)
|
||||
return cls(inv_freq, scaling_factor)
|
||||
return cls(inv_freq, scaling_factor, config.max_position_embeddings)
|
||||
|
||||
def _update_cos_sin_cache(self, dtype, device, seqlen):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
@ -257,17 +260,7 @@ class PositionRotaryEmbedding(nn.Module):
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
|
||||
def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
|
||||
"""
|
||||
Return cos and sin for the asked position ids
|
||||
"""
|
||||
if SYSTEM == "rocm":
|
||||
# For RoCm, we always use float cos/sin to avoid a cast.
|
||||
# For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
|
||||
# But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
|
||||
dtype = torch.float32
|
||||
|
||||
self._update_cos_sin_cache(dtype, position_ids.device, max_s)
|
||||
def get_cos_sin(self, position_ids: torch.Tensor):
|
||||
|
||||
cos = torch.index_select(self._cos_cached, 0, position_ids)
|
||||
sin = torch.index_select(self._sin_cached, 0, position_ids)
|
||||
@ -383,7 +376,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(PositionRotaryEmbedding):
|
||||
class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
super().__init__(inv_freq, scaling_factor, max_position_embeddings)
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
@ -461,7 +454,9 @@ class YarnPositionRotaryEmbedding(PositionRotaryEmbedding):
|
||||
mscale_all_dim: float,
|
||||
):
|
||||
inv_freq = _create_inv_freq(dim, base, device)
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
super().__init__(
|
||||
inv_freq, scaling_factor, max_position_embeddings * self.scaling_factor
|
||||
)
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
@ -546,3 +541,44 @@ def apply_llama3_scaling(
|
||||
new_freqs.append((1 - smooth) * freq / scaling_factor + smooth * freq)
|
||||
|
||||
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
|
||||
|
||||
|
||||
class RotaryPositionEmbeddingMultimodalSections(PositionRotaryEmbedding):
|
||||
def __init__(self, inv_freq: torch.Tensor, scaling_factor: float, sections: list):
|
||||
super().__init__(inv_freq, scaling_factor)
|
||||
self.sections = sections
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
self.section_indices = (
|
||||
torch.arange(len(self.sections))
|
||||
.repeat_interleave(torch.tensor(self.sections))
|
||||
.view(1, 1, -1)
|
||||
.to(inv_freq.device)
|
||||
)
|
||||
|
||||
def _update_cos_sin_cache(
|
||||
self, dtype: torch.dtype, device: torch.device, seqlen: int
|
||||
):
|
||||
# always cache the cos/sin for the full sequence length to avoid
|
||||
# recomputing if the sequence length is smaller than the cached one
|
||||
if (
|
||||
seqlen > self._seq_len_cached
|
||||
or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype
|
||||
):
|
||||
self._seq_len_cached = seqlen
|
||||
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.outer(t, self.inv_freq.to(device=t.device))
|
||||
self._cos_cached = torch.cos(freqs).to(dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(dtype)
|
||||
self._sections = self.section_indices.expand(seqlen, -1, -1)
|
||||
|
||||
def get_cos_sin(
|
||||
self,
|
||||
position_ids: torch.Tensor,
|
||||
):
|
||||
slen = position_ids.shape[0]
|
||||
|
||||
cos = self._cos_cached[position_ids].gather(1, self._sections[:slen])
|
||||
sin = self._sin_cached[position_ids].gather(1, self._sections[:slen])
|
||||
return cos, sin
|
||||
|
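Note: for mrope, each rotary dimension reads its position id from one of the temporal/height/width sections. A tiny standalone example of how `section_indices` is built from `mrope_section` (toy sizes instead of a real config):

import torch

sections = [2, 3, 3]  # stand-in for a real mrope_section such as [16, 24, 24]
section_indices = (
    torch.arange(len(sections))
    .repeat_interleave(torch.tensor(sections))
    .view(1, 1, -1)
)
print(section_indices)  # tensor([[[0, 0, 1, 1, 1, 2, 2, 2]]])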
@ -2,10 +2,8 @@ import torch
|
||||
from torch.nn import functional as F
|
||||
from typing import Iterable, List
|
||||
from text_generation_server.layers.linear import get_linear, FastLinear
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
|
||||
if SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
import habana_frameworks.torch as htorch
|
||||
|
||||
|
||||
class LayerConcat(torch.nn.Module):
|
||||
@ -90,11 +88,7 @@ class TensorParallelHead(SuperLayer):
|
||||
local_out = gather_input.T
|
||||
|
||||
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||
if SYSTEM == "ipex":
|
||||
ipex.distributed.all_gather_into_tensor(
|
||||
world_out, gather_input, group=self.process_group
|
||||
)
|
||||
else:
|
||||
htorch.core.mark_step()
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
world_out, gather_input, group=self.process_group
|
||||
)
|
||||
@ -107,9 +101,8 @@ class TensorParallelHead(SuperLayer):
|
||||
world_output = [
|
||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||
]
|
||||
if SYSTEM == "ipex":
|
||||
ipex.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
else:
|
||||
|
||||
htorch.core.mark_step()
|
||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||
world_output = torch.cat(world_output, dim=-1)
|
||||
return world_output
|
||||
@ -202,9 +195,10 @@ class TensorParallelRowLinear(SuperLayer):
|
||||
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
|
||||
out = super().forward(input)
|
||||
if self.process_group.size() > 1 and reduce:
|
||||
if SYSTEM == "ipex":
|
||||
ipex.distributed.all_reduce(out, group=self.process_group)
|
||||
else:
|
||||
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||
# (which is required for tensor parallel HPUGraph inference)
|
||||
htorch.core.mark_step()
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
||||
@ -242,8 +236,9 @@ class TensorParallelEmbedding(torch.nn.Module):
|
||||
)
|
||||
out = torch.nn.functional.embedding(input, self.weight)
|
||||
if self.reduce and self.process_group.size() > 1:
|
||||
if SYSTEM == "ipex":
|
||||
ipex.distributed.all_reduce(out, group=self.process_group)
|
||||
else:
|
||||
# FIXME(kzawora): this is a workaround for a bug in Habana PT bridge
|
||||
# occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used
|
||||
# (which is required for tensor parallel HPUGraph inference)
|
||||
htorch.core.mark_step()
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
return out
|
||||
|
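Note: the FIXME comments above describe the Habana PT bridge workaround: flush pending lazy-mode graphs with `mark_step()` before every collective when PT_HPU_ENABLE_LAZY_COLLECTIVES=true is set for tensor-parallel HPUGraph inference. A minimal sketch of that pattern, assuming a Gaudi environment with an initialized process group:

import torch
import torch.distributed as dist
import habana_frameworks.torch as htorch


def all_reduce_with_mark_step(out: torch.Tensor, process_group) -> torch.Tensor:
    # Flush queued lazy-mode ops before issuing the collective.
    htorch.core.mark_step()
    dist.all_reduce(out, group=process_group)
    return out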
@ -1,3 +1,5 @@
# ruff: noqa: F821
# the above line disables the `undefined-name` rule for the model type variables
import torch
import os

@ -8,6 +10,7 @@ from huggingface_hub import hf_hub_download, HfApi
from typing import Optional
from pathlib import Path
from typing import List, Dict
import enum

# Needed to properly setup habana_frameworks

@ -35,9 +38,313 @@ from text_generation_server.utils.adapter import (
)
from text_generation_server.adapters.lora import LoraWeights


from text_generation_server.utils.log import log_master
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

__all__ = [
"Model",
"CausalLM",
"Seq2SeqLM",
"get_model_with_lora_adapters",
]
from text_generation_server.models.globals import ATTENTION

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."

FLASH_ATTENTION = False
if ATTENTION == "paged":
FLASH_ATTENTION = True

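On this backend the flash code path follows from the paged-attention mode rather than from a CUDA capability check. A small sketch of how such a toggle is typically consumed; reading `ATTENTION` from an environment variable is an assumption of this sketch, since `models.globals` is not part of this diff:

import os

# Assumption: ATTENTION is ultimately derived from an environment variable, as
# the import of `text_generation_server.models.globals.ATTENTION` suggests.
ATTENTION = os.getenv("ATTENTION", "paged")

FLASH_ATTENTION = ATTENTION == "paged"


def require_flash(model_name: str) -> None:
    # Mirrors how FLASH_ATT_ERROR_MESSAGE is used by the model constructors.
    if not FLASH_ATTENTION:
        raise NotImplementedError(
            "{} requires Flash Attention enabled models.".format(model_name)
        )
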
try:
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
|
||||
from text_generation_server.models.custom_modeling.flash_deepseek_v2_modeling import (
|
||||
FlashDeepseekV2ForCausalLM,
|
||||
DeepseekV2Config,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_deepseek_v3_modeling import (
|
||||
FlashDeepseekV3ForCausalLM,
|
||||
DeepseekV3Config,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
|
||||
FlashCohereForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||
FlashGemmaForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||
FlashGemma2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_dbrx_modeling import (
|
||||
FlashDbrxForCausalLM,
|
||||
DbrxConfig,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
||||
RWConfig,
|
||||
FlashRWForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
FlashGPTNeoXForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.pali_gemma import (
|
||||
PaliGemmaBatch,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import (
|
||||
PaliGemmaForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_phi_modeling import (
|
||||
FlashPhiForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
|
||||
from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
|
||||
from text_generation_server.models.custom_modeling.mllama import (
|
||||
MllamaForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.llava_next import (
|
||||
LlavaNextForConditionalGeneration,
|
||||
)
|
||||
|
||||
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
||||
FlashSantacoderForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import (
|
||||
FlashStarcoder2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_mixtral_modeling import (
|
||||
FlashMixtralForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_gpt2_modeling import (
|
||||
FlashGPT2ForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_gptj_modeling import (
|
||||
FlashGPTJForCausalLM,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||
Idefics2ForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.idefics3 import (
|
||||
Idefics3ForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.qwen2_vl import (
|
||||
Qwen2VLForConditionalGeneration,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.qwen2_5_vl import (
|
||||
Qwen2_5VLForConditionalGeneration,
|
||||
Qwen2_5_VLConfig,
|
||||
Qwen2_5_VLProcessor,
|
||||
)
|
||||
from text_generation_server.layers.attention import SUPPORTS_WINDOWING
|
||||
except ImportError as e:
|
||||
log_master(logger.warning, f"Could not import Flash Attention enabled models: {e}")
|
||||
SUPPORTS_WINDOWING = False
|
||||
FLASH_ATTENTION = False
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
__all__.append(FlashCausalLM)
|
||||
__all__.append(IdeficsCausalLM)
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
DEEPSEEK_V2 = {
|
||||
"type": "deepseek_v2",
|
||||
"name": "Deepseek V2",
|
||||
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V2",
|
||||
}
|
||||
DEEPSEEK_V3 = {
|
||||
"type": "deepseek_v3",
|
||||
"name": "Deepseek V3",
|
||||
"url": "https://huggingface.co/deepseek-ai/DeepSeek-V3",
|
||||
}
|
||||
IDEFICS2 = {
|
||||
"type": "idefics2",
|
||||
"name": "Idefics 2",
|
||||
"url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
|
||||
"multimodal": True,
|
||||
}
|
||||
IDEFICS3 = {
|
||||
"type": "idefics3",
|
||||
"name": "Idefics 3",
|
||||
"url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
|
||||
"multimodal": True,
|
||||
}
|
||||
LLAVA_NEXT = {
|
||||
"type": "llava_next",
|
||||
"name": "Llava Next (1.6)",
|
||||
"url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
|
||||
"multimodal": True,
|
||||
}
|
||||
LLAMA = {
|
||||
"type": "llama",
|
||||
"name": "Llama",
|
||||
"url": "https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f",
|
||||
}
|
||||
PHI3 = {
|
||||
"type": "phi3",
|
||||
"name": "Phi 3",
|
||||
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
|
||||
}
|
||||
GRANITE = {
|
||||
"type": "granite",
|
||||
"name": "Granite",
|
||||
"url": "https://huggingface.co/ibm-granite/granite-3.0-8b-instruct",
|
||||
}
|
||||
GEMMA = {
|
||||
"type": "gemma",
|
||||
"name": "Gemma",
|
||||
"url": "https://huggingface.co/google/gemma-7b",
|
||||
}
|
||||
PALIGEMMA = {
|
||||
"type": "paligemma",
|
||||
"name": "PaliGemma",
|
||||
"url": "https://huggingface.co/google/paligemma-3b-pt-224",
|
||||
}
|
||||
GEMMA2 = {
|
||||
"type": "gemma2",
|
||||
"name": "Gemma2",
|
||||
"url": "https://huggingface.co/collections/google/gemma-2-release-667d6600fd5220e7b967f315",
|
||||
}
|
||||
COHERE = {
|
||||
"type": "cohere",
|
||||
"name": "Cohere",
|
||||
"url": "https://huggingface.co/CohereForAI/c4ai-command-r-plus",
|
||||
}
|
||||
DBRX = {
|
||||
"type": "dbrx",
|
||||
"name": "Dbrx",
|
||||
"url": "https://huggingface.co/databricks/dbrx-instruct",
|
||||
}
|
||||
MAMBA = {
|
||||
"type": "mamba",
|
||||
"name": "Mamba",
|
||||
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
|
||||
}
|
||||
MISTRAL = {
|
||||
"type": "mistral",
|
||||
"name": "Mistral",
|
||||
"url": "https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407",
|
||||
}
|
||||
MIXTRAL = {
|
||||
"type": "mixtral",
|
||||
"name": "Mixtral",
|
||||
"url": "https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1",
|
||||
}
|
||||
GPT_BIGCODE = {
|
||||
"type": "gpt_bigcode",
|
||||
"name": "Gpt Bigcode",
|
||||
"url": "https://huggingface.co/bigcode/gpt_bigcode-santacoder",
|
||||
}
|
||||
PHI = {
|
||||
"type": "phi",
|
||||
"name": "Phi",
|
||||
"url": "https://huggingface.co/microsoft/phi-1_5",
|
||||
}
|
||||
PHI_MOE = {
|
||||
"type": "phimoe",
|
||||
"name": "PhiMoe",
|
||||
"url": "https://huggingface.co/microsoft/Phi-3.5-MoE-instruct",
|
||||
}
|
||||
BAICHUAN = {
|
||||
"type": "baichuan",
|
||||
"name": "Baichuan",
|
||||
"url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
|
||||
}
|
||||
FALCON = {
|
||||
"type": "falcon",
|
||||
"name": "Falcon",
|
||||
"url": "https://huggingface.co/tiiuae/falcon-7b-instruct",
|
||||
}
|
||||
STARCODER2 = {
|
||||
"type": "starcoder2",
|
||||
"name": "StarCoder 2",
|
||||
"url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
|
||||
}
|
||||
QWEN2 = {
|
||||
"type": "qwen2",
|
||||
"name": "Qwen 2",
|
||||
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
|
||||
}
|
||||
QWEN2_VL = {
|
||||
"type": "qwen2_vl",
|
||||
"name": "Qwen 2 VL",
|
||||
"url": "https://huggingface.co/collections/Qwen/qwen2-vl-66cee7455501d7126940800d",
|
||||
}
|
||||
QWEN2_5_VL = {
|
||||
"type": "qwen2_5_vl",
|
||||
"name": "Qwen 2.5 VL",
|
||||
"url": "https://huggingface.co/collections/Qwen/qwen25-66e81a666513e518adb90d9e",
|
||||
}
|
||||
OPT = {
|
||||
"type": "opt",
|
||||
"name": "Opt",
|
||||
"url": "https://huggingface.co/facebook/opt-6.7b",
|
||||
}
|
||||
T5 = {
|
||||
"type": "t5",
|
||||
"name": "T5",
|
||||
"url": "https://huggingface.co/google/flan-t5-xxl",
|
||||
}
|
||||
GALACTICA = {
|
||||
"type": "galactica",
|
||||
"name": "Galactica",
|
||||
"url": "https://huggingface.co/facebook/galactica-120b",
|
||||
}
|
||||
SANTACODER = {
|
||||
"type": "santacoder",
|
||||
"name": "SantaCoder",
|
||||
"url": "https://huggingface.co/bigcode/santacoder",
|
||||
}
|
||||
BLOOM = {
|
||||
"type": "bloom",
|
||||
"name": "Bloom",
|
||||
"url": "https://huggingface.co/bigscience/bloom-560m",
|
||||
}
|
||||
MPT = {
|
||||
"type": "mpt",
|
||||
"name": "Mpt",
|
||||
"url": "https://huggingface.co/mosaicml/mpt-7b-instruct",
|
||||
}
|
||||
GPT2 = {
|
||||
"type": "gpt2",
|
||||
"name": "Gpt2",
|
||||
"url": "https://huggingface.co/openai-community/gpt2",
|
||||
}
|
||||
GPT_NEOX = {
|
||||
"type": "gpt_neox",
|
||||
"name": "Gpt Neox",
|
||||
"url": "https://huggingface.co/EleutherAI/gpt-neox-20b",
|
||||
}
|
||||
GPTJ = {
|
||||
"type": "gptj",
|
||||
"name": "Gptj",
|
||||
"url": "https://huggingface.co/EleutherAI/gpt-j-6b",
|
||||
}
|
||||
IDEFICS = {
|
||||
"type": "idefics",
|
||||
"name": "Idefics",
|
||||
"url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
|
||||
"multimodal": True,
|
||||
}
|
||||
MLLAMA = {
|
||||
"type": "mllama",
|
||||
"name": "Mllama",
|
||||
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"multimodal": True,
|
||||
}
|
||||
|
||||
|
||||
__GLOBALS = locals()
|
||||
for data in ModelType:
|
||||
__GLOBALS[data.name] = data.value["type"]
|
||||
|
||||
# Disable gradients
|
||||
torch.set_grad_enabled(False)
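The loop above copies every `ModelType` member into module globals so the dispatch code below can compare `model_type` against bare names such as `LLAMA` (which is why the file starts with `# ruff: noqa: F821`). A minimal reproduction of the trick:

import enum


class ModelType(enum.Enum):
    # Two entries are enough to show the pattern; the real enum lists many more.
    LLAMA = {"type": "llama", "name": "Llama"}
    GEMMA = {"type": "gemma", "name": "Gemma"}


_GLOBALS = globals()
for data in ModelType:
    _GLOBALS[data.name] = data.value["type"]

# After the loop, LLAMA == "llama" and GEMMA == "gemma", so a config's
# `model_type` string can be matched with `model_type == LLAMA`.
assert LLAMA == "llama" and GEMMA == "gemma"  # noqa: F821
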
|
||||
@ -54,7 +361,7 @@ def get_model(
|
||||
trust_remote_code: bool,
|
||||
max_input_tokens: int,
|
||||
) -> Model:
|
||||
adapt_transformers_to_gaudi()
|
||||
global FLASH_ATTENTION
|
||||
|
||||
if speculate is not None:
|
||||
set_speculate(speculate)
|
||||
@ -178,7 +485,389 @@ def get_model(
|
||||
|
||||
if model_type == "gpt_bigcode":
|
||||
return StarCoder(model_id=model_id, revision=revision, dtype=dtype)
|
||||
kv_cache_dtype = dtype
|
||||
|
||||
if FLASH_ATTENTION:
|
||||
if model_type == DEEPSEEK_V2:
|
||||
head_size = max(
|
||||
config_dict.get("qk_nope_dim", 128)
|
||||
+ config_dict.get("qk_rope_dim", 64),
|
||||
config_dict.get("v_head_dim", 128),
|
||||
)
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashDeepseekV2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
default_dtype=torch.bfloat16,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=DeepseekV2Config,
|
||||
head_size=head_size,
|
||||
)
|
||||
elif model_type == DEEPSEEK_V3:
|
||||
head_size = max(
|
||||
config_dict.get("qk_nope_dim", 128)
|
||||
+ config_dict.get("qk_rope_dim", 64),
|
||||
config_dict.get("v_head_dim", 128),
|
||||
)
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashDeepseekV3ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
default_dtype=torch.bfloat16,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=DeepseekV3Config,
|
||||
head_size=head_size,
|
||||
)
|
||||
|
||||
elif (
|
||||
model_type == GPT_BIGCODE
|
||||
or model_type == GPT2
|
||||
and model_id.startswith("bigcode/")
|
||||
):
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashSantacoderForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
aliases={"transformer.wte.weight": ["lm_head.weight"]},
|
||||
num_kv_heads=1,
|
||||
)
|
||||
elif model_type == GPT2:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGPT2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == GPTJ:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGPTJForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == GPT_NEOX:
|
||||
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
||||
GPTNeoXConfig,
|
||||
)
|
||||
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGPTNeoXForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=GPTNeoXConfig,
|
||||
)
|
||||
elif model_type == PHI:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashPhiForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == PHI_MOE:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashLlamaForCausalLM,
|
||||
config_class=PhiMoEConfig,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == LLAMA or model_type == PHI3 or model_type == GRANITE:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashLlamaForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == BAICHUAN:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashLlamaForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == GEMMA:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGemmaForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == GEMMA2:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashGemma2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == COHERE:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashCohereForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == DBRX:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashDbrxForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Dbrx works better in bfloat16.
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=DbrxConfig,
|
||||
)
|
||||
elif (
|
||||
model_type in ["RefinedWeb", "RefinedWebModel", FALCON]
|
||||
and not sharded
|
||||
and not config_dict.get("alibi", False)
|
||||
):
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashRWForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
aliases={
|
||||
"lm_head.weight": ["transformer.word_embeddings.weight"],
|
||||
"transformer.word_embeddings.weight": ["lm_head.weight"],
|
||||
},
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=RWConfig,
|
||||
)
|
||||
elif model_type == MISTRAL:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashMistralForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == MIXTRAL:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashMixtralForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == STARCODER2:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=FlashStarcoder2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == QWEN2:
|
||||
return FlashCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Qwen2ForCausalLM,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == IDEFICS:
|
||||
return IdeficsCausalLM(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == QWEN2_VL:
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Qwen2VLForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
default_dtype=torch.bfloat16,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == QWEN2_5_VL:
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Qwen2_5VLForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
default_dtype=torch.bfloat16,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
config_class=Qwen2_5_VLConfig,
|
||||
processor_class=Qwen2_5_VLProcessor,
|
||||
)
|
||||
elif model_type == MLLAMA:
|
||||
return MllamaCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=MllamaForConditionalGeneration,
|
||||
batch_class=MllamaCausalLMBatch,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
)
|
||||
elif model_type == IDEFICS2:
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Idefics2ForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
# VRAM usage.
|
||||
processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}},
|
||||
)
|
||||
elif model_type == IDEFICS3:
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=Idefics3ForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
# VRAM usage.
|
||||
processor_kwargs={"size": {"longest_edge": 1456}},
|
||||
)
|
||||
elif model_type == PALIGEMMA:
|
||||
return VlmCausalLM(
|
||||
model_id=model_id,
|
||||
model_class=PaliGemmaForConditionalGeneration,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
# Works better for these models
|
||||
default_dtype=torch.bfloat16,
|
||||
trust_remote_code=trust_remote_code,
|
||||
lora_adapter_ids=lora_adapter_ids,
|
||||
batch_class=PaliGemmaBatch,
|
||||
)
|
||||
elif model_type == LLAVA_NEXT:
|
||||
return VlmCausalLM(
|
||||
model_class=LlavaNextForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
speculator=speculator,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
adapt_transformers_to_gaudi()
|
||||
if model_type == "bloom":
|
||||
return BLOOM(
|
||||
model_id=model_id,
|
||||
|
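The long `if FLASH_ATTENTION: ... elif model_type == ...` chain above is essentially a flat dispatch table from the config's `model_type` string to a wrapper class plus constructor arguments. A condensed, hypothetical sketch of that shape (the registry and `get_model_sketch` are stand-ins; the real function forwards many more arguments such as `quantize`, `speculator` and `kv_cache_dtype`):

from typing import Any, Dict, Tuple, Type

from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
)
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
    FlashMistralForCausalLM,
)

# Hypothetical registry: model_type string -> (wrapper class, extra kwargs).
_DISPATCH: Dict[str, Tuple[Type, Dict[str, Any]]] = {
    "llama": (FlashCausalLM, {"model_class": FlashLlamaForCausalLM}),
    "mistral": (FlashCausalLM, {"model_class": FlashMistralForCausalLM}),
}


def get_model_sketch(model_type: str, model_id: str, **common):
    try:
        cls, extra = _DISPATCH[model_type]
    except KeyError:
        raise ValueError(f"Unsupported model type {model_type}") from None
    return cls(model_id=model_id, **extra, **common)
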
@ -377,7 +377,7 @@ class BloomAttention(nn.Module):
|
||||
past_value.view(-1, *past_value.shape[-2:]),
|
||||
)
|
||||
|
||||
if CUSTOM_KERNELS_ENABLED:
|
||||
if CUSTOM_KERNELS_ENABLED and attention_mask.shape[-1] < 4096:
|
||||
assert self.training is False, "Only foward pass was implemented"
|
||||
assert (
|
||||
attention_mask.shape[-1] < 4096
|
||||
@ -580,7 +580,7 @@ class BloomPreTrainedModel(PreTrainedModel):
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_bloom_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
||||
|
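The Bloom hunk above folds the sequence-length assertion into the kernel guard, so masks of 4096 tokens or more now fall back to the eager path instead of failing. A hedged sketch of that guard-plus-fallback shape (not the repository's kernel wrapper):

import torch

MAX_FUSED_SEQ_LEN = 4096  # the limit hard-coded in the hunk above
CUSTOM_KERNELS_ENABLED = False  # assume the fused extension is absent here


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Eager path: mask is a bool tensor marking padded positions.
    scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
    return torch.softmax(scores.float(), dim=-1).to(scores.dtype)


def attention_probs(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Only take the fused path when the mask is short enough; otherwise
    silently use the eager implementation instead of asserting."""
    if CUSTOM_KERNELS_ENABLED and mask.shape[-1] < MAX_FUSED_SEQ_LEN:
        pass  # a fused kernel would be invoked here; elided in this sketch
    return masked_softmax(scores, mask)
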
@ -28,10 +28,9 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
@ -39,7 +38,6 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
@ -47,11 +45,10 @@ from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.utils.weights import UnquantizedWeight
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
import dropout_layer_norm
|
||||
else:
|
||||
dropout_layer_norm = None
|
||||
from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
|
||||
class CohereRotary(PositionRotaryEmbedding):
|
||||
@ -63,38 +60,25 @@ class CohereRotary(PositionRotaryEmbedding):
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
# Such controlflows may add some overhead.
|
||||
if SYSTEM == "cuda":
|
||||
import rotary_emb
|
||||
|
||||
q1 = query[..., ::2]
|
||||
q2 = query[..., 1::2]
|
||||
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||
|
||||
k1 = key[..., ::2]
|
||||
k2 = key[..., 1::2]
|
||||
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||
elif SYSTEM == "rocm":
|
||||
from vllm._C import ops
|
||||
|
||||
# NOTE: On RoCm systems, we use a ROPE implementatation adapted from VLLM which launches a single kernel for both query/key, contrary to flash-attn implementation used on NVIDIA systems.
|
||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
||||
|
||||
num_tokens = query.shape[0]
|
||||
head_size = query.shape[-1]
|
||||
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
||||
sin = torch.repeat_interleave(sin, 2, dim=-1)
|
||||
cos = torch.repeat_interleave(cos, 2, dim=-1)
|
||||
rotary_dim = cos.shape[-1]
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
query_rot = query[..., :rotary_dim]
|
||||
query_pass = query[..., rotary_dim:]
|
||||
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
|
||||
|
||||
# Inplace operation, updating query and key.
|
||||
ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
||||
elif SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
ipex.llm.functional.rotary_embedding(
|
||||
query, key, sin, cos, query.size(-1), False
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
)
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
key_rot = key[..., :rotary_dim]
|
||||
key_pass = key[..., rotary_dim:]
|
||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
|
||||
|
||||
|
||||
class CohereLayerNorm(nn.Module):
|
||||
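The rotary hunk above routes both query and key through `apply_rotary_pos_emb` from `habana_frameworks.torch.hpex.kernels` in `RotaryPosEmbeddingMode.PAIRWISE`, with cos/sin repeated two-fold along the last dimension. A CPU-runnable sketch of the same pairwise rotation in plain PyTorch (the HPU kernel's offset and mode arguments are treated as assumptions here):

import torch


def pairwise_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Plain-PyTorch equivalent of the pairwise mode used above.

    x:        [num_tokens, heads, rotary_dim], channels laid out as (x0, x1, x2, x3, ...)
    cos, sin: [num_tokens, rotary_dim // 2] before the repeat_interleave.
    """
    cos = torch.repeat_interleave(cos, 2, dim=-1).unsqueeze(1)  # [tokens, 1, rotary_dim]
    sin = torch.repeat_interleave(sin, 2, dim=-1).unsqueeze(1)
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    # Interleave (-x2, x1) back into pairwise channel order, then rotate.
    rotated = torch.stack((-x2, x1), dim=-1).reshape_as(x)
    return x * cos + rotated * sin
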
@ -107,7 +91,6 @@ class CohereLayerNorm(nn.Module):
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
if hidden_states.shape[-1] > 8192 or SYSTEM != "cuda":
|
||||
hidden_states = hidden_states.reshape(
|
||||
-1, self.weight.shape[0], self.weight.shape[1]
|
||||
)
|
||||
@ -121,36 +104,6 @@ class CohereLayerNorm(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
||||
return hidden_states.to(input_dtype)
|
||||
|
||||
(
|
||||
hidden_states,
|
||||
*rest,
|
||||
) = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
None,
|
||||
self.ones,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.eps,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
|
||||
# Required to apply one weight matrix per head
|
||||
hidden_states = hidden_states.view(
|
||||
-1, self.weight.shape[0], self.weight.shape[1]
|
||||
)
|
||||
hidden_states = self.weight * hidden_states
|
||||
hidden_states = hidden_states.view(-1, self.weight.shape[1])
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
@ -229,6 +182,7 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.use_qk_norm = config.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
@ -291,30 +245,37 @@ class FlashCohereAttention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
|
@ -20,17 +20,13 @@ from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple, Any
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
|
||||
if SYSTEM != "ipex":
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
FastLinear,
|
||||
@ -48,6 +44,9 @@ from text_generation_server.layers.layernorm import (
|
||||
)
|
||||
|
||||
|
||||
moe_kernels = None
|
||||
|
||||
|
||||
class DbrxAttentionConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
@ -290,6 +289,7 @@ class DbrxAttention(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
@ -330,30 +330,37 @@ class DbrxAttention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
@ -485,7 +492,8 @@ class BlockSparseMoE(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(x)
|
||||
out = fused_moe(
|
||||
|
||||
out = moe_kernels.fused_moe(
|
||||
x,
|
||||
self.wv1,
|
||||
self.w2,
|
||||
|
@ -33,21 +33,13 @@ from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.utils.weights import Weights
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
from vllm import _custom_C
|
||||
except Exception as e:
|
||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||
|
||||
|
||||
class DeepseekV2Config(PretrainedConfig):
|
||||
def __init__(
|
||||
@ -232,6 +224,8 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.kv_a_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
@ -260,7 +254,7 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cu_seqlen_prefill: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
kv_cache: KVCache,
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
@ -321,30 +315,37 @@ class DeepseekV2Attention(torch.nn.Module):
|
||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Remove padding.
|
||||
@ -387,22 +388,6 @@ class DeepseekV2MLP(nn.Module):
|
||||
self.quantize = config.quantize
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
|
||||
if (
|
||||
SYSTEM == "rocm"
|
||||
and self.hidden_act == "silu"
|
||||
and hidden_states.dtype == torch.float16
|
||||
and hidden_states.shape[0] == 1
|
||||
and not self.quantize
|
||||
):
|
||||
out = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
self.intermediate_size,
|
||||
dtype=hidden_states.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||
return self.down_proj(out, reduce=reduce)
|
||||
else:
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(
|
||||
|
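With the ROCm `LLMM_Silu` fast path removed above, every backend now runs the same fused gate/up projection followed by SiLU gating. A small runnable sketch of that pattern with stand-in weights (the real layers are tensor-parallel linears):

import torch
import torch.nn.functional as F

hidden_size, intermediate_size, tokens = 16, 32, 4
gate_up_proj = torch.nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

hidden_states = torch.randn(tokens, hidden_size)
# gate and up live in one fused projection, split by the view below.
gate_up = gate_up_proj(hidden_states).view(-1, 2, intermediate_size)
out = down_proj(F.silu(gate_up[:, 0]) * gate_up[:, 1])
print(out.shape)  # torch.Size([4, 16])
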
@ -0,0 +1,653 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023, 2024 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from text_generation_server.layers import (
|
||||
FastLinear,
|
||||
SpeculativeHead,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
)
|
||||
from text_generation_server.layers.attention.kv_cache import KVCache, get_kv_scales
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding, get_mscale
|
||||
from text_generation_server.utils.weights import Weights
|
||||
|
||||
|
||||
class DeepseekV3Config(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=102400,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
moe_intermediate_size=1407,
|
||||
num_hidden_layers=30,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
n_shared_experts=2,
|
||||
n_routed_experts=160,
|
||||
ep_size=1,
|
||||
routed_scaling_factor=1.0,
|
||||
kv_lora_rank=512,
|
||||
q_lora_rank=1536,
|
||||
qk_rope_head_dim=64,
|
||||
v_head_dim=128,
|
||||
qk_nope_head_dim=128,
|
||||
topk_method="gready",
|
||||
n_group=8,
|
||||
topk_group=3,
|
||||
num_experts_per_tok=6,
|
||||
moe_layer_freq=1,
|
||||
first_k_dense_replace=0,
|
||||
norm_topk_prob=False,
|
||||
scoring_func="softmax",
|
||||
aux_loss_alpha=0.001,
|
||||
seq_aux=True,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=100000,
|
||||
eos_token_id=100001,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.ep_size = ep_size
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.topk_method = topk_method
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.scoring_func = scoring_func
|
||||
self.aux_loss_alpha = aux_loss_alpha
|
||||
self.seq_aux = seq_aux
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False)
|
||||
if tie_word_embeddings:
|
||||
raise ValueError(
|
||||
"tie_word_embeddings is not supported for Deepseek V2 models."
|
||||
)
|
||||
|
||||
if ep_size != 1:
|
||||
raise ValueError(
|
||||
f"Currently only ep_size == 1 is supported for Deepseek V2 models, was {ep_size}"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV3Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights: Weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.kv_lora_rank = config.kv_lora_rank
|
||||
self.q_lora_rank = config.q_lora_rank
|
||||
self.qk_nope_head_dim = config.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = config.qk_rope_head_dim
|
||||
self.head_size = config.qk_nope_head_dim + config.qk_rope_head_dim
|
||||
self.value_head_size = config.v_head_dim
|
||||
self.head_pad_size = max(self.head_size, self.value_head_size)
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.qk_rope_head_dim,
|
||||
base=config.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
mscale = get_mscale(
|
||||
self.rotary_emb.scaling_factor, self.rotary_emb.mscale_all_dim
|
||||
)
|
||||
self.softmax_scale = self.head_size**-0.5 * mscale * mscale
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
if self.q_lora_rank is None:
|
||||
self.q_proj = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.q_proj",
|
||||
weights=weights,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
else:
|
||||
self.q_a_proj = get_linear(
|
||||
weight=weights.get_weights(f"{prefix}.q_a_proj"),
|
||||
bias=(
|
||||
weights.get_tensor(f"{prefix}.q_a_proj.bias")
|
||||
if config.attention_bias
|
||||
else None
|
||||
),
|
||||
)
|
||||
self.q_a_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.q_a_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.q_b_proj = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.q_b_proj",
|
||||
weights=weights,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
self.kv_a_proj_with_mqa = get_linear(
|
||||
weight=weights.get_weights(f"{prefix}.kv_a_proj_with_mqa"),
|
||||
bias=(
|
||||
weights.get_tensor(f"{prefix}.kv_a_proj_with_mqa.bias")
|
||||
if config.attention_bias
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.kv_a_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.kv_a_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.kv_b_proj = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.kv_b_proj",
|
||||
weights=weights,
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cu_seqlen_prefill: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
):
|
||||
if self.q_lora_rank is None:
|
||||
query = self.q_proj(hidden_states)
|
||||
else:
|
||||
query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
_, query_pe = torch.split(
|
||||
query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
|
||||
compressed_kv, key_pe = torch.split(
|
||||
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
|
||||
)
|
||||
|
||||
key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
|
||||
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
|
||||
-1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
|
||||
)
|
||||
|
||||
key_nope, value = torch.split(
|
||||
kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
|
||||
)
|
||||
|
||||
batch_size, heads, head_dim = query_pe.shape
|
||||
query_pe = (
|
||||
query_pe.view(batch_size, heads, head_dim // 2, 2)
|
||||
.transpose(2, 3)
|
||||
.reshape(batch_size, heads, head_dim)
|
||||
)
|
||||
batch_size, heads, head_dim = key_pe.shape
|
||||
key_pe = (
|
||||
key_pe.view(batch_size, heads, head_dim // 2, 2)
|
||||
.transpose(2, 3)
|
||||
.reshape(batch_size, heads, head_dim)
|
||||
)
|
||||
self.rotary_emb(query_pe, key_pe, cos, sin)
|
||||
|
||||
query[..., self.qk_nope_head_dim :] = query_pe
|
||||
key = torch.empty_like(query)
|
||||
key[..., : self.qk_nope_head_dim] = key_nope
|
||||
key[..., self.qk_nope_head_dim :] = key_pe
|
||||
|
||||
# We need to pad the heads because Flash Attention does not support
|
||||
# qk and v with different head sizes.
|
||||
query = torch.nn.functional.pad(
|
||||
query, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
key = torch.nn.functional.pad(
|
||||
key, (0, self.head_pad_size - self.head_size), value=0
|
||||
)
|
||||
value = torch.nn.functional.pad(
|
||||
value, (0, self.head_pad_size - self.value_head_size), value=0
|
||||
)
|
||||
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Remove padding.
|
||||
attn_output = attn_output[..., : self.value_head_size]
|
||||
|
||||
return self.o_proj(
|
||||
attn_output.reshape(-1, self.num_heads * self.value_head_size)
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
def __init__(self, prefix: str, config, weights, intermediate_size: int):
|
||||
super().__init__()
|
||||
self.hidden_act = config.hidden_act
|
||||
if self.hidden_act != "silu":
|
||||
# Bail out because MoE only supports silu.
|
||||
raise NotImplementedError(
|
||||
"Currently only `silu` is supported as an activation for Deepseek V2."
|
||||
)
|
||||
self.act = ACT2FN[self.hidden_act]
|
||||
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.intermediate_size = intermediate_size // weights.process_group.size()
|
||||
|
||||
# TODO: This is a hotfix to be removed & properly refactored.
|
||||
self.quantize = config.quantize
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, reduce: bool = True):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=reduce
|
||||
)
|
||||
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix,
|
||||
config: DeepseekV3Config,
|
||||
moe_layer_cls: Type[MoELayer],
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = config.hidden_size
|
||||
self.moe_intermediate_size = (
|
||||
config.moe_intermediate_size // weights.process_group.size()
|
||||
)
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
# Gating
|
||||
self.gate = FastLinear.load(config, f"{prefix}.gate", weights, bias=False)
|
||||
|
||||
if config.topk_method == "noaux_tc":
|
||||
self.gate.e_score_correction_bias = torch.zeros(
|
||||
config.n_routed_experts, device=weights.device
|
||||
)
|
||||
else:
|
||||
self.gate.e_score_correction_bias = None
|
||||
|
||||
self.moe_layer = moe_layer_cls(
|
||||
prefix=f"{prefix}.experts",
|
||||
n_experts=config.n_routed_experts,
|
||||
n_expert_group=config.n_group,
|
||||
renormalize=config.norm_topk_prob,
|
||||
topk=config.num_experts_per_tok,
|
||||
topk_group=config.topk_group,
|
||||
weights=weights,
|
||||
scoring_func=config.scoring_func,
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
)
|
||||
assert isinstance(self.moe_layer, MoELayer)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
self.shared_experts = DeepseekV3MLP(
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
config=config,
|
||||
weights=weights,
|
||||
intermediate_size=config.moe_intermediate_size
|
||||
* config.n_shared_experts,
|
||||
)
|
||||
else:
|
||||
self.shared_experts = None
|
||||
|
||||
self.process_group = weights.process_group
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.shared_experts is not None:
|
||||
shared_output = self.shared_experts(x, reduce=False)
|
||||
else:
|
||||
shared_output = None
|
||||
|
||||
router_logits = self.gate(x)
|
||||
|
||||
out = self.moe_layer(x, gating_output=router_logits)
|
||||
|
||||
if shared_output is not None:
|
||||
out = out + shared_output
|
||||
|
||||
# Reduce sum
|
||||
if self.process_group.size() > 1:
|
||||
torch.distributed.all_reduce(out, group=self.process_group)
|
||||
|
||||
return out.view(*x.shape)
|
||||
|
||||
|
||||
class DeepseekV3Layer(nn.Module):
|
||||
def __init__(self, prefix, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"{prefix}.layers.{layer_id}"
|
||||
|
||||
self.self_attn = DeepseekV3Attention(
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
if (
|
||||
config.n_routed_experts is not None
|
||||
and layer_id >= config.first_k_dense_replace
|
||||
and layer_id % config.moe_layer_freq == 0
|
||||
):
|
||||
moe_layer_cls = (
|
||||
SparseMoELayer
|
||||
if SparseMoELayer.is_supported(weights)
|
||||
else DenseMoELayer
|
||||
)
|
||||
self.mlp = DeepseekV3MoE(f"{prefix}.mlp", config, moe_layer_cls, weights)
|
||||
else:
|
||||
self.mlp = DeepseekV3MLP(
|
||||
prefix=f"{prefix}.mlp",
|
||||
config=config,
|
||||
weights=weights,
|
||||
intermediate_size=config.intermediate_size,
|
||||
)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
cu_seqlen_prefill: torch.Tensor,
|
||||
kv_cache,
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
):
|
||||
normed_hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, residual = self.post_attention_layernorm(
|
||||
attn_output, residual
|
||||
)
|
||||
|
||||
output = self.mlp(normed_attn_res_output)
|
||||
|
||||
return output, residual
|
||||
|
||||
|
||||
class DeepseekV3Model(torch.nn.Module):
|
||||
def __init__(self, prefix: str, config, weights: Weights):
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DeepseekV3Layer(
|
||||
prefix,
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashDeepseekV3ForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix: str, config, weights: Weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = DeepseekV3Model(
|
||||
"model" if not prefix else f"{prefix}.model", config, weights
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
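
Only the rows selected by `lm_head_indices` are pushed through the LM head; a minimal sketch of that gather-before-projection step, using a plain `nn.Linear` stand-in and made-up sizes:

import torch
import torch.nn as nn

# Sketch (hypothetical sizes): gather just the positions that need logits
# instead of projecting every token in the batch.
hidden_states = torch.randn(10, 64)        # all token positions in the batch
lm_head = nn.Linear(64, 128, bias=False)   # stands in for the speculative head
lm_head_indices = torch.tensor([3, 9])     # e.g. the last token of each sequence

selected = hidden_states[lm_head_indices]  # [2, 64] instead of [10, 64]
logits = lm_head(selected)                 # [2, vocab_size]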
|
@ -28,7 +28,6 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
@ -40,7 +39,7 @@ from text_generation_server.layers import (
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
@ -208,6 +207,7 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
@ -253,19 +253,25 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
causal=self.causal,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.window_size,
|
||||
softcap=self.softcap,
|
||||
)
|
||||
@ -273,14 +279,14 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
softcap=self.softcap,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
|
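
Both the prefill and decode calls above pass `softcap=self.softcap`; conceptually the kernels soft-cap the raw attention scores with a tanh before the softmax. A rough eager-mode reference, not the fused kernel (the exact placement inside the kernel is assumed):

import torch

def softcapped_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Squash the raw attention logits into (-softcap, softcap) with a tanh
    # instead of hard-clipping them.
    return softcap * torch.tanh(scores / softcap)

scores = torch.randn(4, 4) * 100.0
print(softcapped_scores(scores, softcap=50.0))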
@ -28,9 +28,7 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
PREFILL_IN_KV_CACHE,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -39,6 +37,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
@ -187,6 +186,7 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
@ -224,31 +224,38 @@ class FlashGemmaAttention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=self.causal,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
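
The recurring change in these hunks swaps the free function `reshape_and_cache(...)` for the `kv_cache.store(...)` method. A minimal sketch of what the store amounts to, assuming a cache laid out as `[num_slots, num_kv_heads, head_size]` with one flat slot index per token (the real KVCache also handles `kv_scales` for quantized caches):

import torch

num_slots, num_kv_heads, head_size = 128, 1, 8
key_cache = torch.zeros(num_slots, num_kv_heads, head_size)
value_cache = torch.zeros(num_slots, num_kv_heads, head_size)

def store(key: torch.Tensor, value: torch.Tensor, slots: torch.Tensor) -> None:
    # key/value: [num_tokens, num_kv_heads, head_size]; slots: [num_tokens]
    key_cache[slots] = key
    value_cache[slots] = value

store(torch.randn(3, num_kv_heads, head_size),
      torch.randn(3, num_kv_heads, head_size),
      slots=torch.tensor([5, 6, 42]))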
|
||||
|
@ -24,11 +24,9 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
@ -38,6 +36,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
|
||||
|
||||
def load_qkv(config, prefix: str, weights, head_size, num_heads):
|
||||
@ -195,6 +194,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||
head_size=self.head_size,
|
||||
num_heads=self.num_heads,
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.o_proj = load_row(
|
||||
config,
|
||||
@ -224,30 +224,37 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||
key = key.view(-1, self.num_heads, self.head_size)
|
||||
value = value.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
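
The prefill/decode branch above is driven by `cu_seqlen_prefill`: it carries the cumulative sequence lengths of a ragged prefill batch and is None during decode. A small sketch of that convention with made-up lengths:

import torch

seq_lens = torch.tensor([3, 5, 2])
cu_seqlen_prefill = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0), (1, 0))
# tensor([ 0,  3,  8, 10]): the tokens of sequence i occupy rows [cu[i], cu[i+1])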
|
||||
|
@ -24,11 +24,10 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from typing import Optional, List, Tuple
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
@ -38,13 +37,16 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import (
|
||||
PositionRotaryEmbedding,
|
||||
)
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
from habana_frameworks.torch.hpex.kernels import (
|
||||
RotaryPosEmbeddingMode,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
|
||||
def load_attention(config, prefix: str, weights):
|
||||
@ -78,39 +80,25 @@ class GPTJRotary(PositionRotaryEmbedding):
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
):
|
||||
# Such controlflows may add some overhead.
|
||||
if SYSTEM == "cuda":
|
||||
import rotary_emb
|
||||
|
||||
q1 = query[..., ::2]
|
||||
q2 = query[..., 1::2]
|
||||
|
||||
rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
|
||||
|
||||
k1 = key[..., ::2]
|
||||
k2 = key[..., 1::2]
|
||||
|
||||
rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
|
||||
elif SYSTEM == "rocm":
|
||||
from vllm._C import ops
|
||||
|
||||
# NOTE: On RoCm systems, we use a ROPE implementation adapted from vLLM which launches a single kernel for both query/key, contrary to the flash-attn implementation used on NVIDIA systems.
|
||||
# Compiling flash-attn rotary on RoCm, it appears hipcc is unable to unroll loops, resulting in an even slower inference compared to eager: https://github.com/pytorch/pytorch/issues/113773
|
||||
|
||||
num_tokens = query.shape[0]
|
||||
head_size = query.shape[-1]
|
||||
rope_mode = RotaryPosEmbeddingMode.PAIRWISE
|
||||
sin = torch.repeat_interleave(sin, 2, dim=-1)
|
||||
cos = torch.repeat_interleave(cos, 2, dim=-1)
|
||||
rotary_dim = cos.shape[-1]
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
query_rot = query[..., :rotary_dim]
|
||||
query_pass = query[..., rotary_dim:]
|
||||
query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode)
|
||||
query.copy_(torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape))
|
||||
|
||||
# Inplace operation, updating query and key.
|
||||
ops.rotary_embedding(query, key, head_size, cos, sin, False)
|
||||
elif SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
ipex.llm.functional.rotary_embedding(
|
||||
query, key, sin, cos, query.size(-1), False
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
)
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
key_rot = key[..., :rotary_dim]
|
||||
key_pass = key[..., rotary_dim:]
|
||||
key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode)
|
||||
key.copy_(torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape))
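
The rewritten `GPTJRotary.forward` defers to the Habana `apply_rotary_pos_emb` kernel in `PAIRWISE` mode. As a rough eager-mode reference for what pairwise (GPT-J-style, interleaved) rotary embeddings compute, under assumed shapes:

import torch

def pairwise_rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotate adjacent channel pairs (x0, x1), (x2, x3), ... by the same angle.
    # x: [num_tokens, num_heads, rotary_dim]; cos/sin broadcastable to
    # [num_tokens, 1, rotary_dim // 2] (i.e. before the repeat_interleave above).
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return out.flatten(-2)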
|
||||
|
||||
|
||||
class FlashGPTJAttention(torch.nn.Module):
|
||||
@ -140,6 +128,7 @@ class FlashGPTJAttention(torch.nn.Module):
|
||||
prefix=prefix,
|
||||
weights=weights,
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.o_proj = load_row(
|
||||
config,
|
||||
@ -186,30 +175,37 @@ class FlashGPTJAttention(torch.nn.Module):
|
||||
else:
|
||||
self.rotary_emb(query, key, cos, sin)
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=key,
|
||||
value=value,
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key,
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else value,
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
@ -27,14 +27,16 @@ import torch.distributed
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention import (
|
||||
KVCache,
|
||||
get_kv_scales,
|
||||
)
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
HPUPagedAttentionMetadata,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
@ -57,15 +59,6 @@ from text_generation_server.utils.weights import (
|
||||
)
|
||||
from text_generation_server.layers.fp8 import HybridFP8UnquantLoader
|
||||
|
||||
if SYSTEM != "ipex":
|
||||
pass
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
from vllm import _custom_C
|
||||
except Exception as e:
|
||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||
|
||||
|
||||
def load_attention(config, prefix: str, weights, layer_id):
|
||||
# Only defined in granite.
|
||||
@ -157,7 +150,10 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
# `config.attention_multiplier` is used in Granite
|
||||
self.softmax_scale = getattr(
|
||||
config, "attention_multiplier", self.head_size**-0.5
|
||||
)
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
@ -177,11 +173,13 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||
self.index = index
|
||||
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
bias=getattr(config, "attention_bias", False),
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
@ -202,12 +200,13 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
kv_cache: KVCache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
@ -222,30 +221,42 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
kv_cache.store(
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_scales=self.kv_scales,
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
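
The Llama attention now caches only the rows selected by `prefill_cache_indices` when that tensor is set. A small sketch of the indexing with made-up shapes (whether the indices come from a sliding window or something else is model-specific):

import torch

kv = torch.randn(10, 2, 4, 8)                # [tokens, k/v, num_kv_heads, head_size]
prefill_cache_indices = torch.arange(6, 10)  # keep only the last 4 tokens (assumed)
kv_to_cache = kv[prefill_cache_indices] if prefill_cache_indices is not None else kv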
@ -363,26 +374,6 @@ class LlamaMLP(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
if (
|
||||
SYSTEM == "rocm"
|
||||
and self.hidden_act == "silu"
|
||||
and hidden_states.dtype == torch.float16
|
||||
and hidden_states.shape[0] == 1
|
||||
and not self.quantize
|
||||
and self.hidden_size
|
||||
!= 16384 # TODO: Temporary workaround for `LLMM_Silu` kernel not working with LLama3.1 405B; needs refactoring once fixed.
|
||||
):
|
||||
out = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
self.intermediate_size,
|
||||
dtype=hidden_states.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
_custom_C.LLMM_Silu(
|
||||
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
|
||||
)
|
||||
return self.down_proj(out, adapter_data)
|
||||
else:
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(
|
||||
@ -408,7 +399,7 @@ class FlashLlamaLayer(nn.Module):
|
||||
if SparseMoELayer.is_supported(weights)
|
||||
else DenseMoELayer
|
||||
)
|
||||
self.dense = Phi3MoE(
|
||||
self.mlp = Phi3MoE(
|
||||
f"{prefix}.block_sparse_moe", config, moe_layer_cls, weights
|
||||
)
|
||||
# with moe the layernorms are not rmsnorms and they have bias
|
||||
@ -423,7 +414,7 @@ class FlashLlamaLayer(nn.Module):
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
else:
|
||||
self.dense = LlamaMLP(
|
||||
self.mlp = LlamaMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
|
||||
)
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
@ -437,6 +428,11 @@ class FlashLlamaLayer(nn.Module):
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
# Used in Granite
|
||||
# This could eventually be baked into the weights like we do for the embeddings/lm_head
|
||||
# but this would mean modifying the lora code
|
||||
self.residual_multiplier = getattr(config, "residual_multiplier", None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
@ -448,9 +444,10 @@ class FlashLlamaLayer(nn.Module):
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
cross_attention_states,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
@ -464,16 +461,20 @@ class FlashLlamaLayer(nn.Module):
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
prefill_cache_indices,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
)
|
||||
if self.residual_multiplier is not None:
|
||||
attn_output *= self.residual_multiplier
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.dense(normed_attn_res_output, adapter_data)
|
||||
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||
if self.residual_multiplier is not None:
|
||||
mlp_output *= self.residual_multiplier
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
@ -493,9 +494,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
self.layers.append(
|
||||
FlashLlamaLayer(
|
||||
index=0,
|
||||
prefix=(
|
||||
"model.layers.0" if not prefix else f"{prefix}.model.layers.0"
|
||||
),
|
||||
prefix=f"{prefix}.layers.0",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
@ -511,11 +510,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
self.layers.append(
|
||||
FlashLlamaCrossLayer(
|
||||
index=layer_id,
|
||||
prefix=(
|
||||
f"model.layers.{layer_id}"
|
||||
if not prefix
|
||||
else f"{prefix}.model.layers.{layer_id}"
|
||||
),
|
||||
prefix=(f"{prefix}.layers.{layer_id}"),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
@ -524,11 +519,7 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
self.layers.append(
|
||||
FlashLlamaLayer(
|
||||
index=layer_id,
|
||||
prefix=(
|
||||
f"model.layers.{layer_id}"
|
||||
if not prefix
|
||||
else f"{prefix}.model.layers.{layer_id}"
|
||||
),
|
||||
prefix=(f"{prefix}.layers.{layer_id}"),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
@ -539,18 +530,14 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
self.layers.append(
|
||||
FlashLlamaLayer(
|
||||
index=last_layer_id,
|
||||
prefix=(
|
||||
f"model.layers.{last_layer_id}"
|
||||
if not prefix
|
||||
else f"{prefix}.model.layers.{last_layer_id}"
|
||||
),
|
||||
prefix=(f"{prefix}.layers.{last_layer_id}"),
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
)
|
||||
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm" if not prefix else f"{prefix}.model.norm",
|
||||
prefix=f"{prefix}.norm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
@ -570,19 +557,16 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
adapter_data,
|
||||
cross_attention_states=None,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid indexing in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(position_ids)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
@ -596,9 +580,10 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
adapter_data,
|
||||
cross_attention_states,
|
||||
prefill_cache_indices,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -607,31 +592,51 @@ class FlashLlamaModel(torch.nn.Module):
|
||||
|
||||
|
||||
class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix: str, config, weights):
|
||||
def __init__(self, prefix: str, config, weights, name=None):
|
||||
if name is None:
|
||||
name = "model"
|
||||
super().__init__()
|
||||
|
||||
with no_fp8(weights):
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens"
|
||||
f"{name}.embed_tokens"
|
||||
if not prefix
|
||||
else f"{prefix}.model.embed_tokens"
|
||||
else f"{prefix}.{name}.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = FlashLlamaModel(prefix, config, weights)
|
||||
self.model = FlashLlamaModel(
|
||||
prefix=name if not prefix else f"{prefix}.{name}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
suffix = "model.embed_tokens"
|
||||
else:
|
||||
suffix = "lm_head"
|
||||
|
||||
# Used in Granite
|
||||
embedding_multiplier = getattr(config, "embedding_multiplier", None)
|
||||
if embedding_multiplier is not None:
|
||||
self.embed_tokens.weight.data *= embedding_multiplier
|
||||
prefix = suffix if not prefix or name != "model" else f"{prefix}.{suffix}"
|
||||
with no_fp8(weights):
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix=suffix if not prefix else f"{prefix}.{suffix}",
|
||||
weights=weights,
|
||||
prefix,
|
||||
weights,
|
||||
)
|
||||
|
||||
# Used in Granite
|
||||
self.logits_scaling = getattr(config, "logits_scaling", None)
|
||||
if self.logits_scaling is not None and self.lm_head.head is not None:
|
||||
try:
|
||||
# Scale the weights directly
|
||||
self.lm_head.head.linear.weight.data /= self.logits_scaling
|
||||
self.logits_scaled = True
|
||||
except Exception:
|
||||
self.logits_scaled = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -641,11 +646,11 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
cross_attention_states=None,
|
||||
hpu_attention_meta: Optional[HPUPagedAttentionMetadata] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
@ -656,13 +661,19 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
adapter_data=adapter_data,
|
||||
cross_attention_states=cross_attention_states,
|
||||
hpu_attention_meta=hpu_attention_meta,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
|
||||
# Used in Granite
|
||||
if self.logits_scaling is not None and not self.logits_scaled:
|
||||
logits /= self.logits_scaling
|
||||
if speculative_logits is not None:
|
||||
speculative_logits /= self.logits_scaling
|
||||
|
||||
return logits, speculative_logits
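
`logits_scaling` (used by Granite) is either folded into the `lm_head` weights at load time or, failing that, applied to the logits at runtime; for a purely linear head the two are equivalent, which is why the runtime division can be skipped when `logits_scaled` is True. A tiny check of the equivalence:

import torch

torch.manual_seed(0)
W = torch.randn(16, 8)   # stand-in LM-head weight
x = torch.randn(8)       # stand-in hidden state
scale = 4.0
# scaling the weight once == scaling the logits every forward pass
assert torch.allclose((W / scale) @ x, (W @ x) / scale, atol=1e-6)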
|
||||
|
@ -26,11 +26,10 @@ from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
@ -41,20 +40,12 @@ from text_generation_server.layers import (
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
)
|
||||
|
||||
|
||||
if SYSTEM == "rocm":
|
||||
try:
|
||||
from vllm import _custom_C
|
||||
except Exception as e:
|
||||
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
|
||||
|
||||
|
||||
class MistralConfig(PretrainedConfig):
|
||||
model_type = "mistral"
|
||||
|
||||
@ -160,6 +151,7 @@ class MistralAttention(torch.nn.Module):
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
@ -210,33 +202,38 @@ class MistralAttention(torch.nn.Module):
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
kv_cache.store(
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(
|
||||
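
`window_size_left=self.max_past` asks the prefill kernel for sliding-window attention: each query only sees the most recent `sliding_window` keys. A reference mask for the semantics (the fused kernels never materialize such a mask):

import torch

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Causal mask where query i may only attend to keys j with i - window < j <= i.
    idx = torch.arange(seq_len)
    causal = idx[None, :] <= idx[:, None]
    in_window = idx[:, None] - idx[None, :] < window
    return causal & in_window

print(sliding_window_mask(6, window=3).int())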
@ -300,24 +297,6 @@ class MistralMLP(nn.Module):
|
||||
self.quantize = config.quantize
|
||||
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
if (
|
||||
SYSTEM == "rocm"
|
||||
and self.hidden_act == "silu"
|
||||
and hidden_states.dtype == torch.float16
|
||||
and hidden_states.shape[0] == 1
|
||||
and not self.quantize
|
||||
):
|
||||
out = torch.empty(
|
||||
hidden_states.shape[0],
|
||||
self.intermediate_size,
|
||||
dtype=hidden_states.dtype,
|
||||
device="cuda",
|
||||
)
|
||||
_custom_C.LLMM_Silu(
|
||||
self.gate_up_proj.base_layer.linear.weight, hidden_states, out, 8
|
||||
)
|
||||
return self.down_proj(out, adapter_data)
|
||||
else:
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(
|
||||
|
@ -37,9 +37,8 @@ from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers.moe import DenseMoELayer, MoELayer, SparseMoELayer
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
@ -215,6 +214,7 @@ class MixtralAttention(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
@ -258,33 +258,38 @@ class MixtralAttention(torch.nn.Module):
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
kv_cache.store(
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
@ -29,7 +29,6 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
@ -39,7 +38,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
@ -132,6 +131,7 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
head_size=self.head_size,
|
||||
hidden_size=self.hidden_size,
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
self.dense = load_row(
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=True
|
||||
)
|
||||
@ -165,30 +165,37 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
|
||||
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
|
||||
|
||||
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=qkv[:, 1],
|
||||
value=qkv[:, 2],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
qkv[:, 0],
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else qkv[:, 1],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else qkv[:, 2],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=qkv[:, 0],
|
||||
key=qkv[:, 1],
|
||||
value=qkv[:, 2],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
qkv[:, 0],
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
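
GPT-NeoX keeps query, key and value fused in a single projection and slices them out by index, which is why the cache store and the attention calls use `qkv[:, 0]`, `qkv[:, 1]` and `qkv[:, 2]`. A shape-only sketch with made-up sizes:

import torch

num_tokens, num_heads, head_size = 4, 2, 8
qkv = torch.randn(num_tokens, 3 * num_heads * head_size)   # fused projection output
qkv = qkv.view(num_tokens, 3, num_heads, head_size)
query, key, value = qkv[:, 0], qkv[:, 1], qkv[:, 2]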
|
||||
|
@ -80,6 +80,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||
|
@ -9,7 +9,6 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
@ -19,7 +18,7 @@ from text_generation_server.layers import (
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
)
|
||||
@ -139,6 +138,7 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
# in llama the dense layer is called "o_proj" and has bias=False
|
||||
self.dense = TensorParallelRowLinear.load(
|
||||
@ -188,29 +188,36 @@ class FlashPhiAttention(torch.nn.Module):
|
||||
)
|
||||
|
||||
# Reshape key and value and cache
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_scales=self.kv_scales,
|
||||
kv_cache=kv_cache,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
@ -17,7 +16,7 @@ from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastRMSNorm,
|
||||
@ -86,6 +85,8 @@ class Qwen2Attention(torch.nn.Module):
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
@ -128,33 +129,38 @@ class Qwen2Attention(torch.nn.Module):
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
kv_cache.store(
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
@ -229,7 +235,7 @@ class Qwen2Layer(nn.Module):
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
normed_hidden_states, residual = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
@ -244,15 +250,13 @@ class Qwen2Layer(nn.Module):
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
hidden_states = attn_output + residual

# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
attn_output, res
)

mlp_output = self.mlp(normed_attn_res_output)

return mlp_output, attn_res
hidden_states, residual = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(hidden_states)
hidden_states = mlp_output + residual
return hidden_states
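
The Qwen2 layer drops the fused residual handling of FastRMSNorm and adds the residual explicitly after attention and after the MLP. Structurally both variants compute the same pre-norm residual block; a stripped-down sketch (norm weights omitted):

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Unweighted RMSNorm, enough to show the block structure.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def qwen2_style_layer(x, attn, mlp):
    # x = x + Attn(Norm(x)); x = x + MLP(Norm(x))
    x = x + attn(rms_norm(x))
    x = x + mlp(rms_norm(x))
    return x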
|
||||
|
||||
|
||||
class Qwen2Model(torch.nn.Module):
|
||||
@ -264,9 +268,6 @@ class Qwen2Model(torch.nn.Module):
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Qwen2Layer(
|
||||
@ -290,7 +291,7 @@ class Qwen2Model(torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
@ -301,17 +302,17 @@ class Qwen2Model(torch.nn.Module):
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid indexing in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, true_max_s, hidden_states.dtype
|
||||
position_ids,
|
||||
true_max_s,
|
||||
hidden_states.dtype,
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
@ -325,7 +326,7 @@ class Qwen2Model(torch.nn.Module):
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
hidden_states, _ = self.norm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -346,6 +347,12 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
||||
prefix=f"{prefix}.{suffix}" if prefix else suffix,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.embed_tokens" if prefix else "model.embed_tokens",
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.max_past = config.sliding_window
|
||||
self.max_past_tensor = (
|
||||
torch.tensor(config.sliding_window, device=weights.device)
|
||||
@ -376,8 +383,10 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
||||
# kernel requires the true values
|
||||
seqlen = seqlen.clamp(max=self.max_past_tensor)
|
||||
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
|
@ -12,13 +12,12 @@ from text_generation_server.layers import (
|
||||
TensorParallelRowLinear,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.layernorm import FastLayerNorm
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.attention import (
|
||||
attention,
|
||||
paged_attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
|
||||
@ -79,6 +78,7 @@ class RWConfig(PretrainedConfig):
|
||||
self.alibi = False
|
||||
self.rotary = True
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = 2048
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
# Backward compatibility with n_embed kwarg
|
||||
@ -160,6 +160,7 @@ class FlashRWAttention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=config.bias,
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
self.dense = load_row(
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
||||
)
|
||||
@ -200,30 +201,37 @@ class FlashRWAttention(torch.nn.Module):
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)
|
||||
kv_cache.store(
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, 0],
|
||||
value=kv[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.dense(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
@ -278,6 +286,7 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
weights=weights,
|
||||
bias=config.bias,
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
self.dense = load_row(
|
||||
config, prefix=f"{prefix}.dense", weights=weights, bias=config.bias
|
||||
)
|
||||
@ -312,36 +321,37 @@ class FlashRWLargeAttention(torch.nn.Module):
|
||||
# Inplace rotary
|
||||
self.rotary_emb(query, torch.select(kv, dim=2, index=0), cos, sin)
|
||||
|
||||
reshape_and_cache(
|
||||
kv[:, :, 0].contiguous(),
|
||||
kv[:, :, 1].contiguous(),
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
slots,
|
||||
kv_cache.store(
|
||||
key=kv[:, :, 0].contiguous(),
|
||||
value=kv[:, :, 1].contiguous(),
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv[:, :, 0].contiguous(),
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv[:, :, 1].contiguous(),
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv[:, :, 0],
|
||||
value=kv[:, :, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.dense(
|
||||
|
@ -8,7 +8,6 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
@ -18,7 +17,7 @@ from text_generation_server.layers import (
|
||||
TensorParallelEmbedding,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.gptq import GPTQWeightsLoader
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
@ -259,6 +258,7 @@ class FlashMQAttention(torch.nn.Module):
|
||||
self.c_proj = load_row(
|
||||
config, prefix=f"{prefix}.c_proj", weights=weights, bias=True
|
||||
)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
self.kv_head_mapping = torch.zeros(
|
||||
self.num_heads, dtype=torch.int32, device=weights.device
|
||||
)
|
||||
@ -284,32 +284,37 @@ class FlashMQAttention(torch.nn.Module):
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key_value = key_value.view(-1, 2, 1, self.head_size)
|
||||
|
||||
reshape_and_cache(
|
||||
key_value[:, 0], key_value[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
kv_cache.store(
|
||||
key=key_value[:, 0],
|
||||
value=key_value[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else key_value[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else key_value[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=key_value[:, 0],
|
||||
value=key_value[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.c_proj(attn_output.view(-1, self.num_heads * self.head_size))
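
Santacoder uses multi-query attention: the single shared key/value head from `key_value.view(-1, 2, 1, self.head_size)` serves every query head, which is also why `kv_head_mapping` is all zeros. A dense reference for the broadcast, with made-up sizes and causal masking omitted (the paged kernels do this implicitly):

import torch

num_tokens, num_heads, head_size = 4, 8, 16
query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, 1, head_size).expand(-1, num_heads, -1)
value = torch.randn(num_tokens, 1, head_size).expand(-1, num_heads, -1)

scores = torch.einsum("qhd,khd->hqk", query, key) / head_size**0.5
out = torch.einsum("hqk,khd->qhd", torch.softmax(scores, dim=-1), value)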
|
||||
|
@ -29,17 +29,18 @@ from typing import Optional, List, Tuple
|
||||
from text_generation_server.layers.attention import (
|
||||
paged_attention,
|
||||
attention,
|
||||
reshape_and_cache,
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
)
|
||||
from text_generation_server.layers.attention import PREFILL_IN_KV_CACHE
|
||||
from text_generation_server.layers.attention.kv_cache import get_kv_scales
|
||||
from text_generation_server.layers.layernorm import (
|
||||
FastLayerNorm,
|
||||
FastRMSNorm,
|
||||
@ -110,17 +111,31 @@ class Starcoder2Config(PretrainedConfig):
|
||||
)
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
def load_attention(config, prefix, weights, layer_id):
|
||||
prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
sizes = [
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
]
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
base_layer = _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
base_layer = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
prefixes=prefixes,
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
return TensorParallelMultiAdapterLinear.load(
|
||||
base_layer=base_layer,
|
||||
layer_id=layer_id,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
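
Wrapping the fused q/k/v projection in `TensorParallelMultiAdapterLinear` (and `o_proj` in `TensorParallelAdapterRowLinear`) is what lets LoRA adapters hook into Starcoder2. Very roughly, an adapter contributes a low-rank correction on top of the frozen base projection; the shapes and single-adapter setup below are made up, and the real wrappers slice per fused projection and per request via `adapter_data`:

import torch

hidden, out_features, rank = 32, 64, 4
x = torch.randn(5, hidden)
base = torch.nn.Linear(hidden, out_features, bias=False)   # frozen base projection
lora_a = torch.randn(hidden, rank) * 0.01                  # low-rank adapter factors
lora_b = torch.randn(rank, out_features) * 0.01

y = base(x) + (x @ lora_a) @ lora_b   # base output + low-rank adapter delta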
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
@ -158,6 +173,7 @@ def _load_gqa(config, prefix: str, weights):
|
||||
class Starcoder2Attention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
index: int,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
@ -189,14 +205,23 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
self.query_key_value = load_attention(config, prefix, weights, index)
|
||||
self.kv_scales = get_kv_scales(weights, f"{prefix}")
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
bias=getattr(config, "use_bias", False),
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
index,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
@ -214,8 +239,9 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
@ -233,40 +259,47 @@ class Starcoder2Attention(torch.nn.Module):
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
kv_cache.store(
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
slots=slots,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attn_output = attention(
|
||||
query,
|
||||
kv_cache[0] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 0],
|
||||
kv_cache[1] if PREFILL_IN_KV_CACHE else kv_to_cache[:, 1],
|
||||
seqlen,
|
||||
block_tables,
|
||||
self.softmax_scale,
|
||||
query=query,
|
||||
key=kv_to_cache[:, 0],
|
||||
value=kv_to_cache[:, 1],
|
||||
kv_cache=kv_cache,
|
||||
kv_scales=self.kv_scales,
|
||||
seqlen=seqlen,
|
||||
block_tables=block_tables,
|
||||
softmax_scale=self.softmax_scale,
|
||||
window_size_left=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
attn_output = paged_attention(
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
kv_cache,
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
seqlen,
|
||||
max_s,
|
||||
kv_scales=self.kv_scales,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
return self.o_proj(
|
||||
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||
)
|
||||
|
||||
|
||||
class Starcoder2MLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, index):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
@ -280,27 +313,42 @@ class Starcoder2MLP(nn.Module):
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.c_fc = TensorParallelColumnLinear.load(
|
||||
c_fc = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.c_fc",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.c_proj = TensorParallelRowLinear.load(
|
||||
c_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.c_fc(hidden_states)
|
||||
self.c_fc = TensorParallelMultiAdapterLinear.load(
|
||||
c_fc,
|
||||
layer_id=index,
|
||||
layer_names=[f"{prefix}.c_fc"],
|
||||
sizes=[config.intermediate_size, config.intermediate_size],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
self.c_proj = TensorParallelAdapterRowLinear.load(
|
||||
c_proj,
|
||||
index,
|
||||
"c_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
hidden_states = self.c_fc(hidden_states, adapter_data)
|
||||
hidden_states = self.act(hidden_states)
|
||||
return self.c_proj(hidden_states)
|
||||
return self.c_proj(hidden_states, adapter_data)
|
||||
|
||||
|
||||
class Starcoder2GatedMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, index, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
@ -314,27 +362,47 @@ class Starcoder2GatedMLP(nn.Module):
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
prefixes = [f"{prefix}.gate_proj", f"{prefix}.up_proj"]
|
||||
sizes = [
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
]
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
prefixes=prefixes,
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
index,
|
||||
layer_names=prefixes,
|
||||
sizes=sizes,
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=config.use_bias,
|
||||
)
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
index,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
)
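# Shape sketch (illustrative): gate_up_proj packs gate_proj and up_proj, so a token of
# hidden size H maps to 2 * I features (I = sharded intermediate size); the
# view(-1, 2, I) above splits them back so that act(gate) * up can feed down_proj.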
|
||||
|
||||
|
||||
STARCODER2_NORMALIZATION_CLASSES = {
|
||||
@ -353,11 +421,11 @@ class Starcoder2Layer(nn.Module):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = Starcoder2Attention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights, index=layer_id
|
||||
)
|
||||
|
||||
self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, index=layer_id
|
||||
)
|
||||
|
||||
self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type].load(
|
||||
@ -384,6 +452,7 @@ class Starcoder2Layer(nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
@ -399,6 +468,7 @@ class Starcoder2Layer(nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
@ -406,7 +476,7 @@ class Starcoder2Layer(nn.Module):
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
@ -453,6 +523,7 @@ class Starcoder2Model(torch.nn.Module):
|
||||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
adapter_data,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
@ -476,6 +547,7 @@ class Starcoder2Model(torch.nn.Module):
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
@ -547,6 +619,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
|
||||
max_s,
|
||||
true_max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
@ -750,6 +750,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
||||
# Unused here
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
|
@ -0,0 +1,584 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Idefics3 model."""
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
)
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||
)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
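# Equivalence note (illustrative, not part of the diff): broadcasting the KV heads this
# way matches torch.repeat_interleave along the head dimension, e.g.
#   x = torch.randn(2, 4, 7, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
#   assert torch.equal(repeat_kv(x, 3), x.repeat_interleave(3, dim=1))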
|
||||
|
||||
|
||||
class Idefics3VisionEmbeddings(nn.Module):
|
||||
"""
|
||||
This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
|
||||
resolution.
|
||||
|
||||
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
|
||||
which allows treating images in their native aspect ratio and without the need to resize them to the same
|
||||
fixed size. In particular, we start from the original pre-trained SigLIP model
|
||||
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding="valid",
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||
)
|
||||
self.patch_embedding.bias = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
|
||||
)
|
||||
|
||||
self.num_patches_per_side = self.image_size // self.patch_size
|
||||
self.num_patches = self.num_patches_per_side**2
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.position_embedding", weights=weights
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
|
||||
) -> torch.Tensor:
|
||||
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
max_nb_patches_h, max_nb_patches_w = (
|
||||
max_im_h // self.patch_size,
|
||||
max_im_w // self.patch_size,
|
||||
)
|
||||
boundaries = torch.arange(
|
||||
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
|
||||
)
|
||||
position_ids = torch.full(
|
||||
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
|
||||
)
|
||||
|
||||
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
||||
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||
nb_patches_w = p_attn_mask[0].sum()
|
||||
|
||||
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
||||
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
||||
|
||||
bucket_coords_h = torch.bucketize(
|
||||
fractional_coords_h, boundaries, right=True
|
||||
)
|
||||
bucket_coords_w = torch.bucketize(
|
||||
fractional_coords_w, boundaries, right=True
|
||||
)
|
||||
|
||||
pos_ids = (
|
||||
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
|
||||
).flatten()
|
||||
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
||||
|
||||
position_ids = position_ids.to(self.position_embedding.weight.device)
|
||||
embeddings = embeddings + self.position_embedding(position_ids)
|
||||
return embeddings
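# Position-id sketch (illustrative): for an image whose valid region spans
# nb_patches_h x nb_patches_w patches, the fractional coordinates cover [0, 1) in steps
# of 1/nb_patches_h and 1/nb_patches_w, and bucketize maps them onto the full
# num_patches_per_side x num_patches_per_side grid, so variable-resolution crops reuse
# the fixed-size pre-trained position-embedding table.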
|
||||
|
||||
|
||||
class Idefics3VisionAttention(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = self.embed_dim // self.num_heads
|
||||
if self.head_size * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_size**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||
|
||||
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
|
||||
)
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv = self.qkv(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
self.head_size * self.num_heads,
|
||||
self.head_size * self.num_heads,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
query_states = query_states.view(
|
||||
batch_size, q_len, self.num_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
batch_size, q_len, self.num_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
batch_size, q_len, self.num_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
|
||||
k_v_seq_len = key_states.shape[-2]
|
||||
attn_weights = (
|
||||
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
||||
)
|
||||
|
||||
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
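# Note (illustrative): at inference time (dropout inactive) this eager path is
# equivalent, up to numerical tolerance, to
#   torch.nn.functional.scaled_dot_product_attention(
#       query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0)
# followed by the same transpose/reshape, since SDPA's default scale is 1/sqrt(head_size).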
|
||||
|
||||
|
||||
class Idefics3VisionMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Idefics3EncoderLayer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Idefics3VisionAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
|
||||
)
|
||||
self.mlp = Idefics3VisionMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||
)
|
||||
|
||||
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Idefics3Encoder(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Idefics3EncoderLayer(
|
||||
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Idefics3VisionTransformer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embeddings = Idefics3VisionEmbeddings(
|
||||
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||
)
|
||||
self.encoder = Idefics3Encoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.post_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
):
|
||||
batch_size = pixel_values.size(0)
|
||||
if patch_attention_mask is None:
|
||||
patch_size = self.config.patch_size
|
||||
patch_attention_mask = torch.ones(
|
||||
(
|
||||
batch_size,
|
||||
pixel_values.size(2) // patch_size,
|
||||
pixel_values.size(3) // patch_size,
|
||||
)
|
||||
)
|
||||
patch_attention_mask = patch_attention_mask.to(
|
||||
dtype=torch.bool, device=pixel_values.device
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
|
||||
)
|
||||
|
||||
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
||||
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
||||
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
||||
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
||||
if not torch.any(~patch_attention_mask):
|
||||
patch_attention_mask = None
|
||||
else:
|
||||
patch_attention_mask = _prepare_4d_attention_mask(
|
||||
patch_attention_mask, hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
return last_hidden_state
|
||||
|
||||
|
||||
class Idefics3SimpleMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
input_size = config.vision_config.hidden_size * (config.scale_factor**2)
|
||||
output_size = config.text_config.hidden_size
|
||||
proj = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
|
||||
requires_grad=False,
|
||||
).to(weights.dtype)
|
||||
self.proj = nn.Linear(input_size, output_size, bias=False)
|
||||
self.proj.weight = proj
|
||||
|
||||
def forward(self, x):
|
||||
return self.proj(x)
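# Size sketch (illustrative numbers): with a vision hidden size of 1152 and
# scale_factor=3, the connector input is 1152 * 3**2 = 10368 features per shuffled
# patch, mapped to the text model's hidden size by this single bias-free projection.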
|
||||
|
||||
|
||||
class Idefics3Connector(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
|
||||
self.scale_factor = config.scale_factor
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=2):
|
||||
bsz, seq, embed_dim = x.size()
|
||||
height = width = int(seq**0.5)
|
||||
x = x.view(bsz, height, width, embed_dim)
|
||||
x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
x = x.reshape(
|
||||
bsz,
|
||||
int(width / scale_factor),
|
||||
int(height / scale_factor),
|
||||
embed_dim * (scale_factor**2),
|
||||
)
|
||||
x = x.permute(0, 2, 1, 3)
|
||||
x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
|
||||
return x
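# Shape sketch (illustrative, assumes a square patch grid): with scale_factor=3 and an
# input of (bsz, seq=729, embed_dim), i.e. a 27x27 grid, pixel_shuffle folds each 3x3
# neighbourhood into the channel dimension and returns (bsz, 81, 9 * embed_dim).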
|
||||
|
||||
def forward(self, image_hidden_states):
|
||||
image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
|
||||
image_hidden_states = self.modality_projection(image_hidden_states)
|
||||
return image_hidden_states
|
||||
|
||||
|
||||
class Idefics3ForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = None
|
||||
config.vision_config.speculator = config.speculator
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
# set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
|
||||
# since Idefics3 uses the `embed_tokens` for the final prediction
|
||||
# config.text_config.tie_word_embeddings = True
|
||||
|
||||
vision_config = config.vision_config
|
||||
self.text_model = load_text_model(
|
||||
prefix="model" if not prefix else f"{prefix}.model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
name="text_model",
|
||||
)
|
||||
self.dtype = weights.dtype
|
||||
|
||||
# The vision and connector models are not quantized.
|
||||
with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
|
||||
self.vision_model = Idefics3VisionTransformer(
|
||||
prefix=(
|
||||
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
|
||||
),
|
||||
config=vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
config.quantize = None
|
||||
self.connector = Idefics3Connector(
|
||||
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.image_token_id = config.image_token_id
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
# mask = input_ids == self.config.image_token_index
|
||||
mask = input_ids == self.config.image_token_id
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
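# Merge sketch (illustrative, hypothetical shapes): if input_ids holds 4 image-token
# positions and image_features flattens to 4 rows of size hidden, the boolean mask
# selects exactly those 4 rows of inputs_embeds and overwrites them in place with the
# projected vision features; a count mismatch surfaces as a shape error here.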
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
# Unused here
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
all_states = []
|
||||
all_pixel_values = pixel_values
|
||||
all_pixel_mask = pixel_attention_mask
|
||||
for i in range(batch_size):
|
||||
pixel_values = all_pixel_values.to(
|
||||
dtype=self.dtype
|
||||
) # fp16 compatibility
|
||||
pixel_values = pixel_values[i : i + 1]
|
||||
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(
|
||||
dim=(-1, -2, -3)
|
||||
) != nb_values_per_image
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=(
|
||||
pixel_values.size(0),
|
||||
pixel_values.size(2),
|
||||
pixel_values.size(3),
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask
|
||||
pixel_attention_mask = all_pixel_mask[i : i + 1]
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
1 * num_images, *pixel_attention_mask.shape[2:]
|
||||
)
|
||||
pixel_attention_mask = pixel_attention_mask[
|
||||
real_images_inds
|
||||
].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
patches_subgrid = pixel_attention_mask.unfold(
|
||||
dimension=1, size=patch_size, step=patch_size
|
||||
)
|
||||
patches_subgrid = patches_subgrid.unfold(
|
||||
dimension=2, size=patch_size, step=patch_size
|
||||
)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
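# Mask sketch (illustrative): unfolding a (num_images, H, W) pixel mask twice with
# size=step=patch_size yields (num_images, H//patch_size, W//patch_size, patch_size,
# patch_size); summing the last two dims and testing > 0 keeps a patch whenever it
# covers at least one non-padded pixel.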
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(
|
||||
image_hidden_states,
|
||||
)
|
||||
|
||||
all_states.append(image_hidden_states)
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
@ -46,15 +46,9 @@ from text_generation_server.layers import (
|
||||
FastLinear,
|
||||
)
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.utils.import_utils import SYSTEM
|
||||
from loguru import logger
|
||||
|
||||
if SYSTEM == "cuda":
|
||||
import dropout_layer_norm
|
||||
elif SYSTEM == "rocm":
|
||||
from vllm._C import ops
|
||||
else:
|
||||
dropout_layer_norm = None
|
||||
dropout_layer_norm = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -351,94 +345,18 @@ class IdeficsRMSNorm(nn.Module):
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if SYSTEM == "ipex":
|
||||
import intel_extension_for_pytorch as ipex
|
||||
from vllm_hpu_extension.kernels import rms_norm
|
||||
|
||||
out = ipex.llm.functional.add_rms_norm(
|
||||
residual,
|
||||
hidden_states,
|
||||
self.weight,
|
||||
None,
|
||||
self.variance_epsilon,
|
||||
residual is not None,
|
||||
)
|
||||
return out
|
||||
elif hidden_states.shape[-1] > 8192:
|
||||
orig_shape = hidden_states.shape
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
elif SYSTEM == "cuda":
|
||||
# faster post attention rms norm
|
||||
unwrap = False
|
||||
if len(hidden_states.shape) > 2:
|
||||
unwrap = True
|
||||
shape = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(-1, shape[-1])
|
||||
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
if unwrap:
|
||||
normed_hidden_states = normed_hidden_states.view(*shape)
|
||||
|
||||
return normed_hidden_states
|
||||
elif SYSTEM == "rocm":
|
||||
# We use VLLM RMSNorm kernel that can be compiled for RoCm, instead of Flash Attention ones that can not.
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
unwrap = False
|
||||
if len(hidden_states.shape) > 2:
|
||||
unwrap = True
|
||||
shape = hidden_states.shape
|
||||
hidden_states = hidden_states.reshape(-1, shape[-1])
|
||||
|
||||
out = torch.empty_like(hidden_states)
|
||||
ops.rms_norm(
|
||||
out,
|
||||
hidden_states,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
|
||||
if unwrap:
|
||||
out = out.view(*shape)
|
||||
|
||||
return out
|
||||
residual += hidden_states.view(residual.shape)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
|
||||
)
|
||||
residual = hidden_states
|
||||
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
|
||||
if len(orig_shape) == 2:
|
||||
residual = residual.unsqueeze(0)
|
||||
x = rms_norm().apply(residual, self.weight, self.variance_epsilon)
|
||||
return x.view(orig_shape), residual.view(orig_shape)
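# Reference sketch (illustrative; assumes the fused HPU kernel shares the eps and
# weight semantics of the eager branch above):
#   def rms_norm_ref(x, weight, eps):
#       xf = x.to(torch.float32)
#       xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
#       return weight * xf.to(x.dtype)
# rms_norm().apply(...) should match this up to dtype rounding.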
|
||||
|
||||
|
||||
# this was adapted from LlamaMLP
|
||||
|
@ -14,17 +14,25 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Llava-NeXT model."""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.models.llava_next.modeling_llava_next import (
|
||||
unpad_image,
|
||||
)
|
||||
from optimum.habana.transformers.models import GaudiLlavaNextForConditionalGeneration
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
|
||||
from text_generation_server.layers.attention import Seqlen
|
||||
from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
load_vision_model,
|
||||
)
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
"""
|
||||
@ -32,7 +40,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
|
||||
Args:
|
||||
image_size (`tuple`):
|
||||
The size of the input image in the format (width, height).
|
||||
The size of the input image in the format (height, width).
|
||||
grid_pinpoints (`List`):
|
||||
A list containing possible resolutions. Each item in the list should be a tuple or list
|
||||
of the form `(height, width)`.
|
||||
@ -40,7 +48,7 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
The size of each image patch.
|
||||
|
||||
Returns:
|
||||
tuple: The shape of the image patch grid in the format (width, height).
|
||||
tuple: The shape of the image patch grid in the format (height, width).
|
||||
"""
|
||||
if not isinstance(grid_pinpoints, list):
|
||||
raise ValueError("grid_pinpoints should be a list of tuples or lists")
|
||||
@ -49,13 +57,100 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
|
||||
def unpad_image(tensor, original_size):
|
||||
"""
|
||||
Unpads a PyTorch tensor of a padded and resized image.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`):
|
||||
The image tensor, assumed to be of shape (num_channels, height, width).
|
||||
original_size (`tuple`):
|
||||
The original size of the image (height, width).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The unpadded image tensor.
|
||||
"""
|
||||
original_height, original_width = original_size
|
||||
current_height, current_width = tensor.shape[1:]
|
||||
|
||||
original_aspect_ratio = original_width / original_height
|
||||
current_aspect_ratio = current_width / current_height
|
||||
|
||||
if original_aspect_ratio > current_aspect_ratio:
|
||||
scale_factor = current_width / original_width
|
||||
new_height = int(original_height * scale_factor)
|
||||
padding = (current_height - new_height) // 2
|
||||
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
||||
else:
|
||||
scale_factor = current_height / original_height
|
||||
new_width = int(original_width * scale_factor)
|
||||
padding = (current_width - new_width) // 2
|
||||
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
||||
|
||||
return unpadded_tensor
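# Worked example (illustrative): for a (3, 336, 336) feature map and an original image
# of 672x1344 (height x width), the original aspect ratio 2.0 exceeds the current 1.0,
# so scale_factor = 336 / 1344, new_height = 168, padding = 84, and the returned tensor
# is (3, 168, 336) with the letterbox rows stripped.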
|
||||
|
||||
|
||||
# Copied from transformers.models.llava.modeling_llava.LlavaMultiModalProjector with Llava->LlavaNext
|
||||
class LlavaNextMultiModalProjector(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.linear_1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.act = ACT2FN[config.projector_hidden_act]
|
||||
self.linear_2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.linear_2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
vision_config = config.vision_config
|
||||
# Rather than running every layer and selecting hidden_states[-2], build the vision
# tower with only the layers needed for vision_feature_layer and skip pooling.
|
||||
if config.vision_feature_layer < 0:
|
||||
vision_config.num_hidden_layers += config.vision_feature_layer + 1
|
||||
else:
|
||||
vision_config.num_hidden_layers = config.vision_feature_layer + 1
|
||||
self.vision_tower = load_vision_model(
|
||||
prefix="vision_tower" if not prefix else f"{prefix}.vision_tower",
|
||||
config=config.vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.multi_modal_projector = LlavaNextMultiModalProjector(
|
||||
prefix="multi_modal_projector", config=config, weights=weights
|
||||
)
|
||||
|
||||
self.image_newline = weights.get_tensor("image_newline")
|
||||
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.config = config
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
self.text_model = load_text_model(
|
||||
prefix="language_model" if not prefix else f"{prefix}.language_model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
input_ids: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = input_ids == self.config.image_token_index
|
||||
@ -70,148 +165,55 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused for this model
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
vision_feature_layer: Optional[int] = None,
|
||||
vision_feature_select_strategy: Optional[str] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
token_idx: Optional[torch.Tensor] = None,
|
||||
use_flash_attention: Optional[bool] = True,
|
||||
flash_attention_recompute: Optional[bool] = True,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
|
||||
if token_idx is not None:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
token_idx=token_idx,
|
||||
use_flash_attention=use_flash_attention,
|
||||
flash_attention_recompute=flash_attention_recompute,
|
||||
)
|
||||
|
||||
logits = outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return output
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
pixel_values=None,
|
||||
image_sizes=None,
|
||||
attention_mask=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Inherits from LlavaForConditionalGeneration: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava_next/modeling_llava_next.py#L635
|
||||
The only differences are:
|
||||
- add new args token_idx
|
||||
- add the process of merging images into inputs_embeds
|
||||
"""
|
||||
token_idx = kwargs.get("token_idx", None)
|
||||
if token_idx is None:
|
||||
return super().prepare_inputs_for_generation(
|
||||
input_ids=input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pixel_values=pixel_values,
|
||||
image_sizes=image_sizes,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
use_flash_attention = kwargs.get("use_flash_attention", True)
|
||||
flash_attention_recompute = kwargs.get("flash_attention_recompute", True)
|
||||
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
labels = kwargs.get("labels", None)
|
||||
if (
|
||||
past_key_values is None
|
||||
and pixel_values is not None
|
||||
and input_ids.shape[1] != 1
|
||||
):
|
||||
vision_feature_select_strategy = kwargs.get(
|
||||
"vision_feature_select_strategy", None
|
||||
)
|
||||
vision_feature_layer = kwargs.get("vision_feature_layer", None)
|
||||
vision_feature_select_strategy = (
|
||||
vision_feature_select_strategy
|
||||
if vision_feature_select_strategy is not None
|
||||
else self.config.vision_feature_select_strategy
|
||||
)
|
||||
vision_feature_layer = (
|
||||
vision_feature_layer
|
||||
if vision_feature_layer is not None
|
||||
else self.config.vision_feature_layer
|
||||
)
|
||||
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
# num_special_image_tokens = (input_ids == self.config.image_token_index).sum()
|
||||
# assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid"
|
||||
# 1. Extract the input embeddings
|
||||
inputs_embeds = self.get_input_embeddings()(input_ids)
|
||||
|
||||
# 2. Merge text and images
|
||||
batch_size, num_patches, num_channels, height, width = (
|
||||
pixel_values.shape
|
||||
)
|
||||
reshaped_pixel_values = pixel_values.view(
|
||||
batch_size * num_patches, num_channels, height, width
|
||||
)
|
||||
image_features = self.vision_tower(
|
||||
reshaped_pixel_values,
|
||||
output_hidden_states=True,
|
||||
use_flash_attention=use_flash_attention,
|
||||
flash_attention_recompute=flash_attention_recompute,
|
||||
num_images, num_patches, channels, height, width = pixel_values.shape
|
||||
pixel_values = pixel_values.view(
|
||||
num_images * num_patches, channels, height, width
|
||||
)
|
||||
image_features = self.vision_tower(pixel_values)
|
||||
|
||||
selected_image_feature = image_features.hidden_states[
|
||||
vision_feature_layer
|
||||
]
|
||||
# selected_image_feature = image_features.hidden_states[self.config.vision_feature_layer]
|
||||
# Already done within the clip model
|
||||
selected_image_feature = image_features.last_hidden_state
|
||||
|
||||
if vision_feature_select_strategy == "default":
|
||||
if self.config.vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
elif self.config.vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Strategy `{self.config.vision_feature_select_strategy}` is not supported/valid."
|
||||
)
|
||||
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
|
||||
# split up image_features for each of the individual images
|
||||
# hence we get a list of image_features, each of shape (5, num_patches, hidden_size)
|
||||
# if we assume each image has 5 image features (base image + 4 patches)
|
||||
split_sizes = [image.shape[0] for image in pixel_values]
|
||||
split_sizes = [num_patches] * num_images
|
||||
image_features = torch.split(image_features, split_sizes, dim=0)
|
||||
|
||||
# NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
|
||||
@ -231,22 +233,19 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
|
||||
"The number of patches is not consistent with the image size."
|
||||
)
|
||||
|
||||
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx].tolist(),
|
||||
# Dimensions are intentionally swapped to be bug-compatible with
|
||||
# upstream: https://github.com/LLaVA-VL/LLaVA-NeXT/issues/59
|
||||
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
|
||||
image_sizes[image_idx],
|
||||
self.config.image_grid_pinpoints,
|
||||
self.config.vision_config.image_size,
|
||||
)
|
||||
|
||||
image_feature = image_feature.view(
|
||||
num_patch_height, num_patch_width, height, width, -1
|
||||
)
|
||||
image_feature = image_feature.permute(
|
||||
4, 0, 2, 1, 3
|
||||
).contiguous()
|
||||
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
||||
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
||||
image_feature = unpad_image(
|
||||
image_feature, image_sizes[image_idx]
|
||||
)
|
||||
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
||||
image_feature = torch.cat(
|
||||
(
|
||||
image_feature,
|
||||
@ -267,76 +266,25 @@ class LlavaNextForConditionalGeneration(GaudiLlavaNextForConditionalGeneration):
|
||||
)
|
||||
new_image_features.append(image_feature)
|
||||
image_features = torch.stack(new_image_features, dim=0)
|
||||
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
inputs_embeds, image_features, input_ids
|
||||
)
|
||||
self.image_offset = (
|
||||
image_features.shape[1] - 1
|
||||
) # image_token has occupied 1 token position.
|
||||
# In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
|
||||
# generation with cache
|
||||
elif past_key_values is not None:
|
||||
seq_len = input_ids.shape[1]
|
||||
pad_len = seq_len - token_idx.item()
|
||||
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
|
||||
# Retrieve the first layer to inspect the logits and mask out the hidden states
|
||||
# that are set to 0
|
||||
first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
|
||||
|
||||
# Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
|
||||
batch_index, non_attended_tokens = torch.where(
|
||||
first_layer_past_key_value.float().sum(-2) == 0
|
||||
input_ids, inputs_embeds, image_features
|
||||
)
|
||||
|
||||
# Get the target length
|
||||
past_length = first_layer_past_key_value.shape[-1]
|
||||
extended_attention_mask = torch.ones(
|
||||
(attention_mask.shape[0], past_length),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
adapter_data=adapter_data,
|
||||
)
|
||||
# Filter out only the tokens that can be un-attended, this can happen
|
||||
# if one uses Llava + Fused modules where the cache on the
|
||||
# first iteration is already big enough, or if one passes custom cache
|
||||
valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
|
||||
new_batch_index = batch_index[valid_indices]
|
||||
new_non_attended_tokens = non_attended_tokens[valid_indices]
|
||||
|
||||
# Zero-out the places where we don't need to attend
|
||||
extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
|
||||
|
||||
attention_mask = extended_attention_mask
|
||||
attention_mask[:, -pad_len:] = 0
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
if token_idx is not None:
|
||||
position_ids = (
|
||||
torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
||||
)
|
||||
else:
|
||||
position_ids = position_ids[:, -input_ids.shape[1] :]
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
"token_idx": token_idx,
|
||||
"labels": labels,
|
||||
"use_flash_attention": use_flash_attention,
|
||||
"flash_attention_recompute": flash_attention_recompute,
|
||||
}
|
||||
)
|
||||
|
||||
return model_inputs
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
||||
|
@ -196,6 +196,9 @@ class MambaModel(nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
prefix = "backbone"
|
||||
try:
|
||||
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embeddings", weights)
|
||||
except RuntimeError:
|
||||
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
@ -206,6 +209,9 @@ class MambaModel(nn.Module):
|
||||
self.norm_f = FastRMSNorm.load(
|
||||
f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon
|
||||
)
|
||||
try:
|
||||
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embeddings", weights)
|
||||
except RuntimeError:
|
||||
self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights)
|
||||
self.config = config
|
||||
|
||||
|
@ -19,7 +19,10 @@ from typing import Optional, Tuple, List
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
import flash_attn_2_cuda
|
||||
|
||||
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
import torch.nn.functional as F
|
||||
@ -488,9 +491,14 @@ class MllamaVisionModel(nn.Module):
|
||||
aspect_ratio_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
|
||||
pixel_values.shape
|
||||
)
|
||||
(
|
||||
batch_size,
|
||||
num_concurrent_media,
|
||||
num_tiles,
|
||||
num_channels,
|
||||
height,
|
||||
width,
|
||||
) = pixel_values.shape
|
||||
|
||||
pixel_values = pixel_values.reshape(
|
||||
batch_size * num_concurrent_media * num_tiles, num_channels, height, width
|
||||
@ -698,29 +706,24 @@ class MllamaTextCrossAttention(nn.Module):
|
||||
# logger.info(
|
||||
# f"Q: {query_states.shape} -K {key_states.shape} - V{value_states.shape}"
|
||||
# )
|
||||
attn_output = flash_attn_2_cuda.varlen_fwd(
|
||||
query_states = query_states.unsqueeze(0).transpose(1, 2)
|
||||
key_states = key_states.unsqueeze(0).transpose(1, 2)
|
||||
value_states = value_states.unsqueeze(0).transpose(1, 2)
|
||||
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||
attn_output = fsdpa_op(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
None,
|
||||
cu_seqlen_q,
|
||||
cu_seqlen_k,
|
||||
None,
|
||||
None,
|
||||
None, # block_tables
|
||||
None,
|
||||
max_q,
|
||||
max_k,
|
||||
0.0,
|
||||
self.softmax_scale,
|
||||
False,
|
||||
causal, # Causal
|
||||
-1, # window_size_left,
|
||||
-1,
|
||||
0.0, # softcap
|
||||
False,
|
||||
None,
|
||||
)[0]
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal,
|
||||
scale=None,
|
||||
softmax_mode="None",
|
||||
recompute_mode=None,
|
||||
valid_sequence_lengths=None,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
|
||||
|
||||
attn_output = self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
return attn_output
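# Cross-check sketch (illustrative; assumes the query/key/value head counts already
# match at this point): the FusedSDPA call above computes softmax(QK^T / sqrt(head_dim)) V,
# so for a single packed sequence it should agree, up to numerical tolerance, with
#   torch.nn.functional.scaled_dot_product_attention(
#       query_states, key_states, value_states, is_causal=causal)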
|
||||
|
@ -12,7 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch OPT model."""
|
||||
"""PyTorch OPT model."""
|
||||
|
||||
import random
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@ -99,7 +100,7 @@ class OPTLearnedPositionalEmbedding(nn.Module):
|
||||
self.offset = 2
|
||||
self.weight = nn.Parameter(
|
||||
weights.get_tensor(
|
||||
f"{prefix + '.' if prefix else ''}decoder.embed_positions.weight"
|
||||
f"{prefix if prefix else ''}decoder.embed_positions.weight"
|
||||
)
|
||||
)
|
||||
|
||||
@ -317,7 +318,6 @@ class OPTDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.process_group = weights.process_group
|
||||
self.hidden_size = config.hidden_size
|
||||
prefix = f"{prefix + '.' if prefix else ''}decoder.layers.{layer_id}"
|
||||
self.self_attn = OPTAttention(
|
||||
config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
@ -478,7 +478,12 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
OPTDecoderLayer(layer_id, prefix, config, weights)
|
||||
OPTDecoderLayer(
|
||||
layer_id,
|
||||
prefix=f"{prefix}decoder.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@ -755,6 +760,8 @@ class OPTModel(OPTPreTrainedModel):
|
||||
class OPTForCausalLM(OPTPreTrainedModel):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__(config)
|
||||
if not prefix and any(s.startswith("model") for s in weights.routing.keys()):
|
||||
prefix = "model"
|
||||
|
||||
self.model = OPTModel(prefix, config, weights)
|
||||
|
||||
|
@ -0,0 +1,947 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Qwen2.5 VL model."""
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from text_generation_server.layers.layernorm import FastRMSNorm
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2Model,
|
||||
)
|
||||
|
||||
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py
|
||||
from typing import Union
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.image_utils import ImageInput, VideoInput
|
||||
from transformers.processing_utils import (
|
||||
ProcessingKwargs,
|
||||
ProcessorMixin,
|
||||
Unpack,
|
||||
VideosKwargs,
|
||||
)
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
|
||||
|
||||
class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False):
|
||||
fps: Union[List[float], float]
|
||||
|
||||
|
||||
class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False):
|
||||
videos_kwargs: Qwen2_5_VLVideosProcessorKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
},
|
||||
"videos_kwargs": {"fps": 2.0},
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5_VLProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor.
|
||||
[`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the
|
||||
[`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information.
|
||||
Args:
|
||||
image_processor ([`Qwen2VLImageProcessor`], *optional*):
|
||||
The image processor is a required input.
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The tokenizer is a required input.
|
||||
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
|
||||
in a chat into a tokenizable string.
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
valid_kwargs = ["chat_template"]
|
||||
|
||||
image_processor_class = "AutoImageProcessor"
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(
|
||||
self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
|
||||
):
|
||||
self.image_token = (
|
||||
"<|image_pad|>"
|
||||
if not hasattr(tokenizer, "image_token")
|
||||
else tokenizer.image_token
|
||||
)
|
||||
self.video_token = (
|
||||
"<|video_pad|>"
|
||||
if not hasattr(tokenizer, "video_token")
|
||||
else tokenizer.video_token
|
||||
)
|
||||
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput = None,
|
||||
text: Union[
|
||||
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
|
||||
] = None,
|
||||
videos: VideoInput = None,
|
||||
**kwargs: Unpack[Qwen2_5_VLProcessorKwargs],
|
||||
) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
|
||||
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
|
||||
Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
|
||||
|
||||
Args:
|
||||
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
|
||||
tensor. Both channels-first and channels-last formats are supported.
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
|
||||
The video or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
|
||||
tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
|
||||
return_tensors (`str` or [`~utils.TensorType`], *optional*):
|
||||
If set, will return tensors of a particular framework. Acceptable values are:
|
||||
- `'tf'`: Return TensorFlow `tf.constant` objects.
|
||||
- `'pt'`: Return PyTorch `torch.Tensor` objects.
|
||||
- `'np'`: Return NumPy `np.ndarray` objects.
|
||||
- `'jax'`: Return JAX `jnp.ndarray` objects.
|
||||
|
||||
Returns:
|
||||
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
|
||||
|
||||
- **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
|
||||
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
|
||||
`return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
|
||||
`None`).
|
||||
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
|
||||
- **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
|
||||
- **image_grid_thw** -- List of 3D image grid sizes (temporal, height, width) as seen by the LLM. Returned when `images` is not `None`.
|
||||
- **video_grid_thw** -- List of 3D video grid sizes (temporal, height, width) as seen by the LLM. Returned when `videos` is not `None`.
|
||||
- **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`.
|
||||
"""
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Qwen2_5_VLProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(
|
||||
images=images, videos=None, **output_kwargs["images_kwargs"]
|
||||
)
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_grid_thw = None
|
||||
|
||||
if videos is not None:
|
||||
videos_inputs = self.image_processor(
|
||||
images=None, videos=videos, **output_kwargs["images_kwargs"]
|
||||
)
|
||||
video_grid_thw = videos_inputs["video_grid_thw"]
|
||||
|
||||
fps = output_kwargs["videos_kwargs"].pop("fps", 2.0)
|
||||
if isinstance(fps, (int, float)):
|
||||
second_per_grid_ts = [
|
||||
self.image_processor.temporal_patch_size / fps
|
||||
] * len(video_grid_thw)
|
||||
elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw):
|
||||
second_per_grid_ts = [
|
||||
self.image_processor.temporal_patch_size / tmp for tmp in fps
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number."
|
||||
)
|
||||
videos_inputs.update({"second_per_grid_ts": second_per_grid_ts})
|
||||
|
||||
else:
|
||||
videos_inputs = {}
|
||||
video_grid_thw = None
|
||||
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
if image_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while self.image_token in text[i]:
|
||||
text[i] = text[i].replace(
|
||||
self.image_token,
|
||||
"<|placeholder|>"
|
||||
* (image_grid_thw[index].prod() // merge_length),
|
||||
1,
|
||||
)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", self.image_token)
|
||||
|
||||
if video_grid_thw is not None:
|
||||
merge_length = self.image_processor.merge_size**2
|
||||
index = 0
|
||||
for i in range(len(text)):
|
||||
while self.video_token in text[i]:
|
||||
text[i] = text[i].replace(
|
||||
self.video_token,
|
||||
"<|placeholder|>"
|
||||
* (video_grid_thw[index].prod() // merge_length),
|
||||
1,
|
||||
)
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", self.video_token)
|
||||
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs})
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def post_process_image_text_to_text(self, generated_outputs):
|
||||
"""
|
||||
Post-process the output of the model to decode the text.
|
||||
|
||||
Args:
|
||||
generated_outputs (`torch.Tensor` or `np.ndarray`):
|
||||
The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
|
||||
or `(sequence_length,)`.
|
||||
|
||||
Returns:
|
||||
`List[str]`: The decoded text.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(
|
||||
generated_outputs,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
image_processor_input_names = self.image_processor.model_input_names
|
||||
names_from_processor = list(
|
||||
dict.fromkeys(tokenizer_input_names + image_processor_input_names)
|
||||
)
|
||||
return names_from_processor + ["second_per_grid_ts"]
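# A minimal usage sketch (illustrative only; the model id, image file and prompt
# below are hypothetical and not part of this file):
#
#     from transformers import AutoImageProcessor, AutoTokenizer
#     from PIL import Image
#
#     image_processor = AutoImageProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
#     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
#     processor = Qwen2_5_VLProcessor(image_processor, tokenizer)
#     inputs = processor(
#         images=Image.open("example.jpg"),
#         text="<|image_pad|> Describe the image.",
#         return_tensors="pt",
#     )
#     # `inputs` then holds input_ids, attention_mask, pixel_values and
#     # image_grid_thw; each <|image_pad|> token is expanded to one placeholder
#     # per merged image patch before tokenization.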
|
||||
|
||||
|
||||
# Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py
|
||||
class Qwen2_5_VLVisionConfig(PretrainedConfig):
|
||||
model_type = "qwen2_5_vl"
|
||||
base_config_key = "vision_config"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth=32,
|
||||
hidden_size=3584,
|
||||
hidden_act="silu",
|
||||
intermediate_size=3420,
|
||||
num_heads=16,
|
||||
in_channels=3,
|
||||
patch_size=14,
|
||||
spatial_merge_size=2,
|
||||
spatial_patch_size=14,
|
||||
temporal_patch_size=2,
|
||||
tokens_per_second=4,
|
||||
window_size=112,
|
||||
out_hidden_size=3584,
|
||||
fullatt_block_indexes=[7, 15, 23, 31],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.depth = depth
|
||||
self.hidden_size = hidden_size
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_heads = num_heads
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.spatial_patch_size = spatial_patch_size
|
||||
self.spatial_merge_size = spatial_merge_size
|
||||
self.temporal_patch_size = temporal_patch_size
|
||||
self.tokens_per_second = tokens_per_second
|
||||
self.window_size = window_size
|
||||
self.fullatt_block_indexes = fullatt_block_indexes
|
||||
self.out_hidden_size = out_hidden_size
|
||||
|
||||
|
||||
class Qwen2_5_VLConfig(PretrainedConfig):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=152064,
|
||||
hidden_size=8192,
|
||||
intermediate_size=29568,
|
||||
num_hidden_layers=80,
|
||||
num_attention_heads=64,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-05,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1000000.0,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=80,
|
||||
attention_dropout=0.0,
|
||||
vision_config=None,
|
||||
rope_scaling=None,
|
||||
**kwargs,
|
||||
):
|
||||
if vision_config is not None:
|
||||
self.vision_config = Qwen2_5_VLVisionConfig(**vision_config)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.attention_dropout = attention_dropout
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
|
||||
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
||||
# TODO: @raushan update config in the hub
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
if self.rope_scaling["type"] == "mrope":
|
||||
self.rope_scaling["type"] = "default"
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
tensor: torch.Tensor, freqs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = tensor.dtype
|
||||
tensor = tensor.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
||||
output = output.to(orig_dtype)
|
||||
return output
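# Descriptive note: `freqs` holds one angle per (position, frequency) pair with
# shape (seq, head_dim // 2); cos/sin are duplicated along the last dim and
# broadcast against a (1, seq, n_heads, head_dim) tensor, i.e. the standard
# rotate-half RoPE applied over the flattened vision patch sequence.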
|
||||
|
||||
|
||||
class Qwen2_5VLAttention(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size // weights.process_group.size()
|
||||
self.head_dim = config.hidden_size // config.num_heads
|
||||
self.num_heads = config.num_heads // weights.process_group.size()
|
||||
|
||||
self.qkv = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.qkv",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_heads,
|
||||
)
|
||||
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
|
||||
|
||||
self.proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
max_seqlen: int,
|
||||
) -> torch.Tensor:
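# Shape walkthrough (descriptive comment; seq = number of vision patches,
# n_heads = self.num_heads after tensor-parallel sharding, d = head size):
#   hidden_state (seq, embed_dim) -> qkv -> 3 x (seq, embed_dim),
#   viewed as (seq, n_heads, d), rotary embeddings applied, then
#   unsqueezed/transposed to (1, n_heads, seq, d) for FusedSDPA and reshaped
#   back to (seq, embed_dim) before the output projection. In this HPU port
#   cu_seqlens and max_seqlen are not consumed by the fused kernel call below.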
|
||||
# apply the qkv linear layer to the hidden state
|
||||
qkv = self.qkv(hidden_state)
|
||||
query, key, value = qkv.split(
|
||||
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
|
||||
)
|
||||
|
||||
# reshape the query, key, and value tensors
|
||||
_shape = (
|
||||
hidden_state.shape[0],
|
||||
self.num_heads,
|
||||
self.embed_dim // self.num_heads,
|
||||
)
|
||||
query = query.view(*_shape)
|
||||
key = key.view(*_shape)
|
||||
value = value.view(*_shape)
|
||||
|
||||
# apply rotary positional embeddings
|
||||
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
|
||||
0
|
||||
)
|
||||
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
# calc maximum sequence length for any batch
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
causal = False
|
||||
|
||||
# execute sdpa
|
||||
query = query.unsqueeze(0).transpose(1, 2)
|
||||
key = key.unsqueeze(0).transpose(1, 2)
|
||||
value = value.unsqueeze(0).transpose(1, 2)
|
||||
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||
attn_output = fsdpa_op(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal,
|
||||
scale=None,
|
||||
softmax_mode="None",
|
||||
recompute_mode=None,
|
||||
valid_sequence_lengths=None,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
|
||||
|
||||
# reshape output to original dimensions
|
||||
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2_5VLVisionMLP(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.up = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.up_proj", weights=weights, config=config, bias=True
|
||||
)
|
||||
self.gate = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.gate_proj", weights=weights, config=config, bias=True
|
||||
)
|
||||
self.down = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.down_proj", weights=weights, config=config, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
gate_states = self.gate(hidden_states)
|
||||
up_states = self.up(hidden_states)
|
||||
activated_states = self.activation_fn(gate_states) * up_states
|
||||
down_states = self.down(activated_states)
|
||||
return down_states
|
||||
|
||||
|
||||
class Qwen2_5VLVisionBlock(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.attn = Qwen2_5VLAttention(
|
||||
prefix=f"{prefix}.attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.norm1 = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.norm1",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.norm2 = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.norm2",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.mlp = Qwen2_5VLVisionMLP(
|
||||
prefix=f"{prefix}.mlp",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
|
||||
) -> torch.Tensor:
|
||||
norm1_out, _ = self.norm1(hidden_states)
|
||||
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
|
||||
hidden_states = hidden_states + attn_out
|
||||
norm2_out, _ = self.norm2(hidden_states)
|
||||
mlp_out = self.mlp(norm2_out)
|
||||
hidden_states = hidden_states + mlp_out
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2_5VLPatchMerger(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
|
||||
self.patch_merger_ln_q = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.ln_q",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
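# The merger groups spatial_merge_size ** 2 neighbouring patch embeddings into
# one vector (RMSNorm, then view(-1, self.hidden_size)) and projects it to the
# text-model width with a two-layer GELU MLP.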
|
||||
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = F.gelu(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2_5VisionModel(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.spatial_merge_size = config.spatial_merge_size
|
||||
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_channels=config.in_channels,
|
||||
out_channels=config.hidden_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=False,
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
|
||||
)
|
||||
head_dim = config.hidden_size // config.num_heads
|
||||
|
||||
theta = 10000.0
|
||||
dim = head_dim // 2
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen2_5VLVisionBlock(
|
||||
prefix=f"{prefix}.blocks.{i}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for i in range(config.depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen2_5VLPatchMerger(
|
||||
prefix=f"{prefix}.merger",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.temporal_patch_size = config.temporal_patch_size
|
||||
self.spatial_patch_size = config.spatial_patch_size
|
||||
self.in_channels = config.in_channels
|
||||
self.embed_dim = config.hidden_size
|
||||
self.window_size = config.window_size
|
||||
self.patch_size = config.patch_size
|
||||
self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
|
||||
self.fullatt_block_indexes = config.fullatt_block_indexes
|
||||
|
||||
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, _, hidden_size = hidden_state.shape
|
||||
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
|
||||
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
|
||||
return hidden_state
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = (
|
||||
self.window_size // self.spatial_merge_size // self.patch_size
|
||||
)
|
||||
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h, llm_grid_w = (
|
||||
grid_h // self.spatial_merge_size,
|
||||
grid_w // self.spatial_merge_size,
|
||||
)
|
||||
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
|
||||
grid_t, llm_grid_h, llm_grid_w
|
||||
)
|
||||
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
|
||||
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
|
||||
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
|
||||
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
|
||||
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
|
||||
index_padded = index_padded.reshape(
|
||||
grid_t,
|
||||
num_windows_h,
|
||||
vit_merger_window_size,
|
||||
num_windows_w,
|
||||
vit_merger_window_size,
|
||||
)
|
||||
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
|
||||
grid_t,
|
||||
num_windows_h * num_windows_w,
|
||||
vit_merger_window_size,
|
||||
vit_merger_window_size,
|
||||
)
|
||||
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
|
||||
index_padded = index_padded.reshape(-1)
|
||||
index_new = index_padded[index_padded != -100]
|
||||
window_index.append(index_new + window_index_id)
|
||||
cu_seqlens_tmp = (
|
||||
seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
|
||||
)
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
|
||||
return window_index, cu_window_seqlens
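# Worked example (descriptive comment, assuming the default vision config with
# window_size=112, patch_size=14, spatial_merge_size=2, so a merger window
# spans 112 // 2 // 14 = 4 merged patches per side):
#   for grid_thw = [[1, 4, 4]] the merged grid is 2 x 2; it is padded to 4 x 4
#   and falls into a single window, so window_index = [0, 1, 2, 3] and
#   cu_window_seqlens = [0, 16] (4 merged positions * spatial_merge_unit of 4).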
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
grid_thw: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# reshape the input tensor for processing
|
||||
shape = (
|
||||
-1,
|
||||
self.in_channels,
|
||||
self.temporal_patch_size,
|
||||
self.spatial_patch_size,
|
||||
self.spatial_patch_size,
|
||||
)
|
||||
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
|
||||
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
|
||||
# TODO: revisit to see if we can avoid some of these reshapes
|
||||
|
||||
# find the position ids for the input tensor based on the grid_thw
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
||||
hpos_ids = hpos_ids.flatten()
|
||||
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
||||
wpos_ids = wpos_ids.flatten()
|
||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
|
||||
# apply the positional embeddings to the position ids
|
||||
seq = torch.arange(
|
||||
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
|
||||
)
|
||||
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
|
||||
seq_len = hidden_states.shape[0]
|
||||
patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
||||
og_shape = (seq_len, -1)
|
||||
|
||||
hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
|
||||
og_shape
|
||||
)
|
||||
rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
|
||||
og_shape
|
||||
)
|
||||
|
||||
rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
|
||||
|
||||
cu_window_seqlens = torch.tensor(
|
||||
cu_window_seqlens,
|
||||
device=hidden_states.device,
|
||||
dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
|
||||
)
|
||||
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
|
||||
|
||||
# create a cu_seqlens tensor to be used in the attention mask
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
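# Descriptive example: for grid_thw = [[1, 4, 4]] the per-image patch counts
# are [16], so cu_seqlens = [0, 16] and max_seqlen = 16; with several images
# the boundaries accumulate, e.g. two such images give [0, 16, 32].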
|
||||
|
||||
# iteratively apply the blocks to the hidden states
|
||||
for layer_num, block in enumerate(self.blocks):
|
||||
# NOTE: qwen2_5_vl.py has a concept of full attention blocks
|
||||
# that are applied at specific layers.
|
||||
if layer_num in self.fullatt_block_indexes:
|
||||
cu_seqlens_now = cu_seqlens
|
||||
else:
|
||||
cu_seqlens_now = cu_window_seqlens
|
||||
|
||||
hidden_states = block(
|
||||
hidden_states, cu_seqlens_now, rotary_pos_emb, max_seqlen
|
||||
)
|
||||
|
||||
# apply the final patch merger to the hidden states
|
||||
hidden_states = self.merger(hidden_states)
|
||||
reverse_indices = torch.argsort(window_index)
|
||||
hidden_states = hidden_states[reverse_indices, :]
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2_5VLForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
config.vision_config.quantize = None
|
||||
config.vision_config.speculator = config.speculator
|
||||
# set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
|
||||
# returns rope_scaling.type == "default" for Qwen2_5-VL model at the moment
|
||||
if (
|
||||
hasattr(config, "rope_scaling")
|
||||
and config.rope_scaling is not None
|
||||
and config.rope_scaling.get("type", None) == "default"
|
||||
):
|
||||
config.rope_scaling.update({"rope_type": "mrope"})
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vision_start_token_id = config.vision_start_token_id
|
||||
self.vision_end_token_id = config.vision_end_token_id
|
||||
self.image_token_id = config.image_token_id
|
||||
self.video_token_id = config.video_token_id
|
||||
self.spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.visual = Qwen2_5VisionModel(
|
||||
prefix="visual", config=config.vision_config, weights=weights
|
||||
)
|
||||
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
|
||||
if config.tie_word_embeddings:
|
||||
suffix = "model.embed_tokens"
|
||||
else:
|
||||
suffix = "lm_head"
|
||||
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix=suffix if not prefix else f"{prefix}.{suffix}",
|
||||
weights=weights,
|
||||
)
|
||||
self.device = weights.device
|
||||
|
||||
# based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
|
||||
# modified to first find segments then initialize position ids for each segment
|
||||
# Steps:
|
||||
# locate all vision and text segments
|
||||
# calculate `vision_segment_lengths` for each vision segment to be used as offset
|
||||
# calculate `text_segment_lengths` for each text segment to be used as offset
|
||||
# create position ids for each vision segment based on the image grid
|
||||
# create position ids for each text segment
|
||||
# combine all the position ids
|
||||
# the final segment is the difference between the last vision segment and the end of the input
|
||||
# combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
|
||||
def get_position_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
image_grid_thw: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if image_grid_thw is None:
|
||||
return (
|
||||
torch.arange(input_ids.shape[0], device=input_ids.device)
|
||||
.unsqueeze(1)
|
||||
.repeat(1, 3)
|
||||
)
|
||||
|
||||
spatial_merge_size = self.spatial_merge_size
|
||||
vision_start_token_id = self.vision_start_token_id
|
||||
vision_end_token_id = self.vision_end_token_id
|
||||
device = input_ids.device
|
||||
dtype = input_ids.dtype
|
||||
input_ids_len = input_ids.shape[0]
|
||||
|
||||
vision_starts = torch.where(input_ids == vision_start_token_id)[0]
|
||||
vision_ends = torch.where(input_ids == vision_end_token_id)[0]
|
||||
vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
|
||||
prev_vision_end = torch.cat(
|
||||
[torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
|
||||
)
|
||||
text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
|
||||
vision_widths_max = torch.cat(
|
||||
[
|
||||
torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
|
||||
image_grid_thw[:-1, 2] // spatial_merge_size,
|
||||
]
|
||||
)
|
||||
vision_segment_lengths = vision_widths_max + text_lengths_between_vision
|
||||
vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
|
||||
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
|
||||
|
||||
# create position ids for each vision segment based on the image grid
|
||||
llm_pos_ids_list = []
|
||||
for i, _ in enumerate(vision_segments):
|
||||
t, h, w = (
|
||||
image_grid_thw[i][0],
|
||||
image_grid_thw[i][1] // spatial_merge_size,
|
||||
image_grid_thw[i][2] // spatial_merge_size,
|
||||
)
|
||||
t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
|
||||
h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
|
||||
w_indices = torch.arange(w, device=device).repeat(t * h)
|
||||
image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
|
||||
|
||||
# offset by the position of the last vision segment
|
||||
im = image_position_ids + vision_segment_lengths[i]
|
||||
llm_pos_ids_list.append(im)
|
||||
|
||||
# create position ids for each text segment
|
||||
text_ranges = [
|
||||
torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
|
||||
+ text_segment_lengths[i]
|
||||
for i, seq_len in enumerate(text_lengths_between_vision)
|
||||
]
|
||||
|
||||
full_llm_pos_ids_list = [
|
||||
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
|
||||
]
|
||||
|
||||
|
||||
|
||||
max_s = full_llm_pos_ids_list[-1].max() + 1
|
||||
final_text_len = input_ids_len - vision_ends[-1]
|
||||
if final_text_len > 0:
|
||||
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
|
||||
full_llm_pos_ids_list.append(m + max_s)
|
||||
|
||||
position_ids = (
|
||||
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
|
||||
)
|
||||
return position_ids
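# Worked example (descriptive comment, hypothetical input): for a sequence
# [text, <vision_start>, 4 image tokens, <vision_end>, text] with
# image_grid_thw = [[1, 4, 4]] and spatial_merge_size = 2, the image spans a
# 1 x 2 x 2 grid. The two tokens before the image (text and <vision_start>)
# get positions 0, 1 on all three axes; the image tokens get (t, h, w) ids
# offset by 2, i.e. t = [2, 2, 2, 2], h = [2, 2, 3, 3], w = [2, 3, 2, 3]; the
# two trailing tokens (<vision_end> and the final text token) continue from
# the running maximum with 4, 5 on all axes, giving a (seq_len, 3) tensor.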
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
# Unused in this model
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# apply the visual model to the pixel values if they are provided
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.visual(
|
||||
pixel_values, grid_thw=image_grid_thw
|
||||
).squeeze(0)
|
||||
inputs_embeds[input_ids == self.image_token_id] = image_embeds
|
||||
|
||||
hidden_states = self.text_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
@ -0,0 +1,522 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Qwen2 VL model."""
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
|
||||
from habana_frameworks.torch.hpex.kernels import FusedSDPA
|
||||
from vllm_hpu_extension.utils import ModuleFusedSDPA
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
import torch.nn.functional as F
|
||||
|
||||
from text_generation_server.layers.layernorm import FastLayerNorm, FastRMSNorm
|
||||
from text_generation_server.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
)
|
||||
from text_generation_server.layers.attention import (
|
||||
Seqlen,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
|
||||
Qwen2Model,
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
tensor: torch.Tensor, freqs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = tensor.dtype
|
||||
tensor = tensor.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
||||
output = output.to(orig_dtype)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen2VLAttention(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.embed_dim // weights.process_group.size()
|
||||
self.head_dim = config.hidden_size // config.num_heads
|
||||
self.num_heads = config.num_heads // weights.process_group.size()
|
||||
|
||||
self.qkv = TensorParallelColumnLinear.load_qkv(
|
||||
config,
|
||||
prefix=f"{prefix}.qkv",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
num_heads=self.num_heads,
|
||||
num_key_value_heads=self.num_heads,
|
||||
)
|
||||
self.qkv.linear.bias = weights.get_sharded(f"{prefix}.qkv.bias", dim=0)
|
||||
self.proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.proj",
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.softmax_scale = 1.0 / np.sqrt(self.embed_dim // self.num_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_state: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rotary_pos_emb: torch.Tensor,
|
||||
max_seqlen: int,
|
||||
) -> torch.Tensor:
|
||||
# apply the qkv linear layer to the hidden state
|
||||
qkv = self.qkv(hidden_state)
|
||||
query, key, value = qkv.split(
|
||||
[self.embed_dim, self.embed_dim, self.embed_dim], dim=1
|
||||
)
|
||||
|
||||
# reshape the query, key, and value tensors
|
||||
_shape = (
|
||||
hidden_state.shape[0],
|
||||
self.num_heads,
|
||||
self.embed_dim // self.num_heads,
|
||||
)
|
||||
query = query.view(*_shape)
|
||||
key = key.view(*_shape)
|
||||
value = value.view(*_shape)
|
||||
|
||||
# apply rotary positional embeddings
|
||||
query = apply_rotary_pos_emb_vision(query.unsqueeze(0), rotary_pos_emb).squeeze(
|
||||
0
|
||||
)
|
||||
key = apply_rotary_pos_emb_vision(key.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
||||
|
||||
# calc maximum sequence length for any batch
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
causal = False
|
||||
|
||||
# execute sdpa
|
||||
query = query.unsqueeze(0).transpose(1, 2)
|
||||
key = key.unsqueeze(0).transpose(1, 2)
|
||||
value = value.unsqueeze(0).transpose(1, 2)
|
||||
fsdpa_op = ModuleFusedSDPA(FusedSDPA)
|
||||
attn_output = fsdpa_op(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=causal,
|
||||
scale=None,
|
||||
softmax_mode="None",
|
||||
recompute_mode=None,
|
||||
valid_sequence_lengths=None,
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2).squeeze(0).contiguous()
|
||||
# reshape output to original dimensions
|
||||
attn_output = attn_output.reshape(hidden_state.shape[0], -1)
|
||||
attn_output = self.proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class Qwen2VLVisionMLP(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.fc1", weights=weights, config=config, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.fc2", weights=weights, config=config, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VLVisionBlock(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.attn = Qwen2VLAttention(
|
||||
prefix=f"{prefix}.attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.norm1 = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.norm1",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.norm2 = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.norm2",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.mlp = Qwen2VLVisionMLP(
|
||||
prefix=f"{prefix}.mlp",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen
|
||||
) -> torch.Tensor:
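# As used here, self.norm1(...) returns the normalized activations together
# with the residual (the pre-norm input), so the block follows the usual
# pre-norm pattern: x = x + attn(norm1(x)); x = x + mlp(norm2(x)).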
|
||||
norm1_out, residual = self.norm1(hidden_states)
|
||||
attn_out = self.attn(norm1_out, cu_seqlens, rotary_pos_emb, max_seqlen)
|
||||
hidden_states = attn_out + residual
|
||||
norm2_out, residual = self.norm2(hidden_states)
|
||||
hidden_states = hidden_states + self.mlp(norm2_out)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VLPatchMerger(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.hidden_size = config.embed_dim * (config.spatial_merge_size**2)
|
||||
self.patch_merger_ln_q = FastLayerNorm.load(
|
||||
prefix=f"{prefix}.ln_q",
|
||||
weights=weights,
|
||||
eps=1e-6,
|
||||
)
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states) -> torch.Tensor:
|
||||
hidden_states, _ = self.patch_merger_ln_q(hidden_states)
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = F.gelu(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VisionModel(nn.Module):
|
||||
def __init__(self, *, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.spatial_merge_size = config.spatial_merge_size
|
||||
kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
|
||||
self.patch_embedding = nn.Conv3d(
|
||||
in_channels=config.in_chans,
|
||||
out_channels=config.embed_dim,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=False,
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
|
||||
)
|
||||
head_dim = config.embed_dim // config.num_heads
|
||||
# TODO: replace with static positional embeddings once implemented
|
||||
theta = 10000.0
|
||||
dim = head_dim // 2
|
||||
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen2VLVisionBlock(
|
||||
prefix=f"{prefix}.blocks.{i}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
for i in range(config.depth)
|
||||
]
|
||||
)
|
||||
self.merger = Qwen2VLPatchMerger(
|
||||
prefix=f"{prefix}.merger",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
self.temporal_patch_size = config.temporal_patch_size
|
||||
self.spatial_patch_size = config.spatial_patch_size
|
||||
self.in_channels = config.in_channels
|
||||
self.embed_dim = config.embed_dim
|
||||
|
||||
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, _, hidden_size = hidden_state.shape
|
||||
class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
|
||||
hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
|
||||
return hidden_state
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
grid_thw: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# reshape the input tensor for processing
|
||||
shape = (
|
||||
-1,
|
||||
self.in_channels,
|
||||
self.temporal_patch_size,
|
||||
self.spatial_patch_size,
|
||||
self.spatial_patch_size,
|
||||
)
|
||||
pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
|
||||
hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
|
||||
# TODO: revisit to see if we can avoid some of these reshapes
|
||||
|
||||
# find the position ids for the input tensor based on the grid_thw
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
||||
hpos_ids = hpos_ids.flatten()
|
||||
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
wpos_ids = wpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
w // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
)
|
||||
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
||||
wpos_ids = wpos_ids.flatten()
|
||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
|
||||
# apply the positional embeddings to the position ids
|
||||
seq = torch.arange(
|
||||
max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
|
||||
)
|
||||
rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
|
||||
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
||||
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)
|
||||
|
||||
# create a cu_seqlens tensor to be used in the attention mask
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1])
|
||||
# iteratively apply the blocks to the hidden states
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, cu_seqlens, rotary_pos_emb, max_seqlen)
|
||||
|
||||
# apply the final patch merger to the hidden states
|
||||
hidden_states = self.merger(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
config.vision_config.quantize = None
|
||||
config.vision_config.speculator = config.speculator
|
||||
# set rope_scaling.type == "mrope" since AutoConfig.from_pretrained incorrectly
|
||||
# returns rope_scaling.type == "default" for Qwen2-VL model at the moment
|
||||
if (
|
||||
hasattr(config, "rope_scaling")
|
||||
and config.rope_scaling is not None
|
||||
and config.rope_scaling.get("type", None) == "default"
|
||||
):
|
||||
config.rope_scaling.update({"rope_type": "mrope"})
|
||||
self.hidden_size = config.hidden_size
|
||||
self.vision_start_token_id = config.vision_start_token_id
|
||||
self.vision_end_token_id = config.vision_end_token_id
|
||||
self.image_token_id = config.image_token_id
|
||||
self.video_token_id = config.video_token_id
|
||||
self.spatial_merge_size = config.vision_config.spatial_merge_size
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.visual = Qwen2VisionModel(
|
||||
prefix="visual", config=config.vision_config, weights=weights
|
||||
)
|
||||
self.text_model = Qwen2Model(prefix=None, config=config, weights=weights)
|
||||
if config.tie_word_embeddings:
|
||||
suffix = "model.embed_tokens"
|
||||
else:
|
||||
suffix = "lm_head"
|
||||
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix=suffix if not prefix else f"{prefix}.{suffix}",
|
||||
weights=weights,
|
||||
)
|
||||
self.norm = FastRMSNorm.load(
|
||||
prefix="model.norm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
self.device = weights.device
|
||||
|
||||
# based on https://github.com/huggingface/transformers/blob/e284c7e954abe12c34b50461c17f8115a0afe115/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1391
|
||||
# modified to first find segments then initialize position ids for each segment
|
||||
# Steps:
|
||||
# locate all vision and text segments
|
||||
# calculate `vision_segment_lengths` for each vision segment to be used as offset
|
||||
# calculate `text_segment_lengths` for each text segment to be used as offset
|
||||
# create position ids for each vision segment based on the image grid
|
||||
# create position ids for each text segment
|
||||
# combine all the position ids
|
||||
# the final segment is the difference between the last vision segment and the end of the input
|
||||
# combine all the position ids and reshape to (3, input_ids_len) then swap dimensions to (input_ids_len, 3)
|
||||
def get_position_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
image_grid_thw: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if image_grid_thw is None:
|
||||
return (
|
||||
torch.arange(input_ids.shape[0], device=input_ids.device)
|
||||
.unsqueeze(1)
|
||||
.repeat(1, 3)
|
||||
)
|
||||
|
||||
spatial_merge_size = self.spatial_merge_size
|
||||
vision_start_token_id = self.vision_start_token_id
|
||||
vision_end_token_id = self.vision_end_token_id
|
||||
device = input_ids.device
|
||||
dtype = input_ids.dtype
|
||||
input_ids_len = input_ids.shape[0]
|
||||
|
||||
vision_starts = torch.where(input_ids == vision_start_token_id)[0]
|
||||
vision_ends = torch.where(input_ids == vision_end_token_id)[0]
|
||||
vision_segments = torch.stack((vision_starts, vision_ends), dim=1)
|
||||
prev_vision_end = torch.cat(
|
||||
[torch.zeros(1, device=vision_ends.device, dtype=dtype), vision_ends[:-1]]
|
||||
)
|
||||
text_lengths_between_vision = vision_segments[:, 0] - prev_vision_end + 1
|
||||
vision_widths_max = torch.cat(
|
||||
[
|
||||
torch.zeros(1, device=image_grid_thw.device, dtype=dtype),
|
||||
image_grid_thw[:-1, 2] // spatial_merge_size,
|
||||
]
|
||||
)
|
||||
vision_segment_lengths = vision_widths_max + text_lengths_between_vision
|
||||
vision_segment_lengths = vision_segment_lengths.cumsum(dim=0)
|
||||
text_segment_lengths = vision_segment_lengths - text_lengths_between_vision
|
||||
|
||||
# create position ids for each vision segment based on the image grid
|
||||
llm_pos_ids_list = []
|
||||
for i, _ in enumerate(vision_segments):
|
||||
t, h, w = (
|
||||
image_grid_thw[i][0],
|
||||
image_grid_thw[i][1] // spatial_merge_size,
|
||||
image_grid_thw[i][2] // spatial_merge_size,
|
||||
)
|
||||
t_indices = torch.arange(t, device=device).repeat_interleave(h * w)
|
||||
h_indices = torch.arange(h, device=device).repeat_interleave(w).repeat(t)
|
||||
w_indices = torch.arange(w, device=device).repeat(t * h)
|
||||
image_position_ids = torch.stack([t_indices, h_indices, w_indices], dim=0)
|
||||
|
||||
# offset by the position of the last vision segment
|
||||
im = image_position_ids + vision_segment_lengths[i]
|
||||
llm_pos_ids_list.append(im)
|
||||
|
||||
# create position ids for each text segment
|
||||
text_ranges = [
|
||||
torch.arange(seq_len, device=device).view(1, -1).expand(3, -1)
|
||||
+ text_segment_lengths[i]
|
||||
for i, seq_len in enumerate(text_lengths_between_vision)
|
||||
]
|
||||
|
||||
full_llm_pos_ids_list = [
|
||||
item for sublist in zip(text_ranges, llm_pos_ids_list) for item in sublist
|
||||
]
|
||||
max_s = full_llm_pos_ids_list[-1].max() + 1
|
||||
final_text_len = input_ids_len - vision_ends[-1]
|
||||
if final_text_len > 0:
|
||||
m = torch.arange(final_text_len, device=device).view(1, -1).expand(3, -1)
|
||||
full_llm_pos_ids_list.append(m + max_s)
|
||||
|
||||
position_ids = (
|
||||
torch.cat(full_llm_pos_ids_list, dim=1).reshape(3, -1).transpose(0, 1)
|
||||
)
|
||||
return position_ids
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
cross_attention_states: Optional[torch.Tensor] = None,
|
||||
image_indices=None,
|
||||
):
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# apply the visual model to the pixel values if they are provided
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
if pixel_values is not None:
|
||||
image_embeds = self.visual(
|
||||
pixel_values, grid_thw=image_grid_thw
|
||||
).squeeze(0)
|
||||
inputs_embeds[input_ids == self.image_token_id] = image_embeds
|
||||
|
||||
hidden_states = self.text_model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
|
||||
return FlashLlamaForCausalLM(prefix, config, weights)
|
||||
return FlashLlamaForCausalLM(prefix, config, weights, name=name)
|
||||
elif config.model_type == "mistral":
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
@ -17,6 +17,12 @@ def load_text_model(prefix, config, weights, name=None):
|
||||
)
|
||||
|
||||
return FlashGemmaForCausalLM(prefix, config, weights, causal=False)
|
||||
elif config.model_type == "gemma2":
|
||||
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
|
||||
FlashGemma2ForCausalLM,
|
||||
)
|
||||
|
||||
return FlashGemma2ForCausalLM(prefix, config, weights)
|
||||
elif config.model_type == "paligemma":
|
||||
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
|
||||
FlashGemmaForCausalLM,
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,53 +1,31 @@
|
||||
import torch
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
from loguru import logger
|
||||
from text_generation_server.utils.log import log_master
|
||||
|
||||
REQUEST_LOGPROBS = os.getenv("REQUEST_LOGPROBS", "0").lower() in {"1", "true"}
|
||||
ATTENTION = os.getenv("ATTENTION", "default")
|
||||
# default_prefix_caching = "1" if ATTENTION in {"flashinfer", "flashdecoding"} else "0"
|
||||
PREFIX_CACHING = os.getenv("PREFIX_CACHING", "0").lower() in {
|
||||
"1",
|
||||
"true",
|
||||
}
|
||||
PREFILL_CHUNKING = os.getenv("PREFILL_CHUNKING", "0").lower() in {"1", "true"}
|
||||
log_master(logger.info, f"Using prefix caching = {PREFIX_CACHING}")
|
||||
_expected = {"paged", "flashdecoding", "flashinfer", "default"}
|
||||
_expected = {"paged", "default"}
|
||||
assert (
|
||||
ATTENTION in _expected
|
||||
), f"Attention is not valid {ATTENTION}, expected {_expected}"
|
||||
log_master(logger.info, f"Using Attention = {ATTENTION}")
|
||||
|
||||
if PREFIX_CACHING and ATTENTION not in {"flashinfer", "flashdecoding"}:
|
||||
raise RuntimeError("Prefix caching is only supported with flashinfer")
|
||||
|
||||
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
|
||||
TGI_WIGGLE_ROOM = float(os.getenv("TGI_WIGGLE_ROOM", "0.90"))
|
||||
assert TGI_WIGGLE_ROOM > 0
|
||||
assert TGI_WIGGLE_ROOM < 1
|
||||
|
||||
# This is overridden by the cli
|
||||
BLOCK_SIZE: int
|
||||
if ATTENTION == "flashdecoding":
|
||||
BLOCK_SIZE = 256
|
||||
elif ATTENTION == "flashinfer":
|
||||
BLOCK_SIZE = 1
|
||||
else:
|
||||
BLOCK_SIZE = 16
|
||||
|
||||
# This is overridden by the cli
|
||||
cuda_graphs = os.getenv("CUDA_GRAPHS")
|
||||
if cuda_graphs is not None:
|
||||
try:
|
||||
cuda_graphs = [int(item) for item in cuda_graphs.split(",")]
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Could not parse cuda graphs {cuda_graphs}, expected comma separated list for batch sizes to run on: {e}"
|
||||
)
|
||||
else:
|
||||
cuda_graphs = None
|
||||
BLOCK_SIZE = 128
|
||||
|
||||
CUDA_GRAPHS = cuda_graphs
|
||||
|
||||
# This is overridden at model loading.
|
||||
global MODEL_ID
|
||||
|
@ -34,9 +34,6 @@ from text_generation_server.utils import (
)
from text_generation_server.utils.quantization import get_loader

from text_generation_server.utils.import_utils import SYSTEM

tracer = trace.get_tracer(__name__)

@ -596,22 +593,8 @@ class IdeficsCausalLM(Model):
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
# 9b seems to work correctly enough in float16, but 80b seems
# to be really saturating for f16.
dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
self.device, self.dtype = device, dtype

config = AutoConfig.from_pretrained(
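
A hedged sketch of the device/dtype selection that this constructor (and the Seq2SeqLM one further down) converges on for this backend: prefer HPU with bfloat16, otherwise fall back to CPU with float32. torch.hpu only exists when the Habana PyTorch bridge is loaded, so the check is guarded:

import torch

def pick_device_and_dtype(dtype=None):
    # HPU is present only with habana_frameworks' PyTorch bridge installed.
    if hasattr(torch, "hpu") and torch.hpu.is_available():
        device = torch.device("hpu")
        dtype = torch.bfloat16 if dtype is None else dtype
    else:
        device = torch.device("cpu")
        dtype = torch.float32 if dtype is None else dtype
    return device, dtype

device, dtype = pick_device_and_dtype()
print(device, dtype)
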
@ -11,10 +11,6 @@ from transformers import (
)
from text_generation_server.models.vlm_causal_lm import VlmCausalLMBatch, VlmCausalLM
from text_generation_server.pb import generate_pb2
from text_generation_server.models.flash_causal_lm import (
block_tables_to_ragged,
)
from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
from text_generation_server.layers.attention import Seqlen

@ -254,34 +250,7 @@ class MllamaCausalLM(VlmCausalLM):
# This makes sure the max_s for the decode pass is correct.
max_s = min(self.max_past(), max_s)

bs = input_ids.shape[0]
# Try to find an associated cuda graph
bs = input_ids.shape[0]
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
if sorted_padded_bs:
# Get associated cuda graph
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
else:
cuda_graph = None
if (
cu_seqlen_prefill is not None
or cuda_graph is None
# Only run cuda graphs when there's no images.
or batch.cross_attention_states is not None
):
input_lengths = input_lengths + prefix_lens_tensor
if PREFIX_CACHING:
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
)
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=cu_seqlen_prefill,
input_lengths_tensor=input_lengths,
prefix_lens_tensor=prefix_lens_tensor,
):
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
@ -321,37 +290,3 @@ class MllamaCausalLM(VlmCausalLM):
if batch.pixel_values is not None:
batch.pixel_values = None
return logits, speculative_logits

# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
cuda_graph["position_ids"][: position_ids.shape[0]] = position_ids
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=batch.input_lengths,
prefix_lens=batch.prefix_lens,
)
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
else:
cuda_graph["block_tables"][
: block_tables.shape[0], : block_tables.shape[1]
] = block_tables
cuda_graph["slots"].fill_(0)
cuda_graph["slots"][: slots.shape[0]] = slots
cuda_graph["input_lengths"].zero_()
cuda_graph["input_lengths"][: input_lengths.shape[0]] = (
input_lengths + prefix_lens_tensor
)

# Replay the graph
cuda_graph["graph"].replay()

# Slice output to the correct shape
speculative_logits = (
cuda_graph["speculative_logits"][:bs]
if cuda_graph["speculative_logits"] is not None
else None
)
logits = cuda_graph["logits"][:bs]
return logits, speculative_logits
@ -33,6 +33,7 @@ class Model(ABC):
sliding_window: Optional[int] = None,
speculate: Optional[int] = None,
adapter_id: str = BASE_MODEL_ADAPTER_ID,
support_chunking: bool = False,
):
self.model_id = model_id
self.model = model.eval()
@ -10,7 +10,6 @@ from transformers import (
AutoConfig,
)
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
@ -555,20 +554,9 @@ class Seq2SeqLM(Model):
):
self.quantize = quantize
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.

device = torch.device("hpu")
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype

config = config_class.from_pretrained(
model_id,
@ -1,15 +1,13 @@
import os
import torch

from torch.distributed import ProcessGroup
from datetime import timedelta
from loguru import logger

# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))

# CUDA memory fraction
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
MEMORY_FRACTION = float(os.getenv("HPU_MEMORY_FRACTION", "0.8"))

class FakeBarrier:
@ -17,10 +15,11 @@ class FakeBarrier:
pass

class FakeGroup:
class FakeGroup(ProcessGroup):
def __init__(self, rank, size):
self._rank = rank
self._size = size
super().__init__(rank, size)

def allreduce(self, *args, **kwargs):
return FakeBarrier()
@ -42,42 +41,11 @@ class FakeGroup:
def rank(self):
return self._rank

def _get_backend_name(self):
return "fake"

def initialize_torch_distributed():

world_size = int(os.getenv("WORLD_SIZE", "1"))

options = None
if torch.cuda.is_available():
from torch.distributed import ProcessGroupNCCL

# Set the device id.
assert WORLD_SIZE <= torch.cuda.device_count(), "Each process is one gpu"
device = RANK % torch.cuda.device_count()
torch.cuda.set_device(device)
torch.cuda.set_per_process_memory_fraction(MEMORY_FRACTION, device)
backend = "nccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
elif torch.hpu.is_available():
backend = "hccl"
n_hpus = torch.hpu.device_count()
if world_size > n_hpus:
raise ValueError(
f"WORLD_SIZE ({world_size}) is higher than the number of available HPUs ({n_hpus})."
)
else:
try:
import oneccl_bindings_for_pytorch # noqa: F401

backend = "ccl"
if os.getenv("CCL_WORKER_COUNT", None) is None:
os.environ["CCL_WORKER_COUNT"] = str(1)
except ImportError:
backend = "gloo"
options = None

if WORLD_SIZE == 1:
return FakeGroup(RANK, WORLD_SIZE), RANK, WORLD_SIZE
else:
@ -87,11 +55,10 @@ def initialize_torch_distributed():
if not torch.distributed.is_initialized():
# Call the init process.
torch.distributed.init_process_group(
backend=backend,
backend="hccl",
world_size=WORLD_SIZE,
rank=RANK,
timeout=timedelta(seconds=60),
pg_options=options,
timeout=timedelta(seconds=120),
)
else:
logger.warning("torch.distributed is already initialized.")
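
A small sketch of the single- vs multi-process split that initialize_torch_distributed keeps after the hunk above: with WORLD_SIZE == 1 a fake group is returned and no backend is initialized, otherwise the process group is created with the "hccl" backend and a 120s timeout. The FakeGroup-like class below is a simplified stand-in, not the class from the diff:

import os
from datetime import timedelta
import torch

class _FakeGroupSketch:
    """Simplified stand-in for the FakeGroup used when WORLD_SIZE == 1."""
    def __init__(self, rank, size):
        self._rank, self._size = rank, size
    def rank(self):
        return self._rank
    def size(self):
        return self._size

def init_distributed_sketch():
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if world_size == 1:
        # Single process: no collective backend needed.
        return _FakeGroupSketch(rank, world_size), rank, world_size
    # Multi-process path: HCCL backend, 120s timeout, as in the hunk above.
    torch.distributed.init_process_group(
        backend="hccl",
        world_size=world_size,
        rank=rank,
        timeout=timedelta(seconds=120),
    )
    return torch.distributed.group.WORLD, rank, world_size

group, rank, world_size = init_distributed_sketch()
print(type(group).__name__, rank, world_size)
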
@ -1,75 +1,28 @@
import torch
from loguru import logger
import os

import importlib.util
def get_hpu_free_memory(device, memory_fraction):
from habana_frameworks.torch.hpu import memory_stats

def is_ipex_available():
return importlib.util.find_spec("intel_extension_for_pytorch") is not None

def get_cuda_free_memory(device, memory_fraction):
total_free_memory, _ = torch.cuda.mem_get_info(device)
total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
return free_memory

def get_xpu_free_memory(device, memory_fraction):
total_memory = torch.xpu.get_device_properties(device).total_memory
device_id = device.index
memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
mem_stats = memory_stats(device_id)
logger.info(f"mem_stats: {mem_stats}")
total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
free_memory = max(
0,
int(
total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
),
0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"])
)
return free_memory

def get_cpu_free_memory(device, memory_fraction):
import psutil
from text_generation_server.utils.dist import WORLD_SIZE

mem = psutil.virtual_memory()
free_memory = int(mem.available * 0.95 / WORLD_SIZE)
return free_memory
def synchronize_hpu(device):
torch.hpu.synchronize()

def noop(*args, **kwargs):
pass

SYSTEM = None
if torch.version.hip is not None:
SYSTEM = "rocm"
empty_cache = torch.cuda.empty_cache
synchronize = torch.cuda.synchronize
get_free_memory = get_cuda_free_memory
elif torch.version.cuda is not None and torch.cuda.is_available():
SYSTEM = "cuda"
empty_cache = torch.cuda.empty_cache
synchronize = torch.cuda.synchronize
get_free_memory = get_cuda_free_memory
elif is_ipex_available():
SYSTEM = "ipex"
import intel_extension_for_pytorch # noqa: F401

if hasattr(torch, "xpu") and torch.xpu.is_available():
empty_cache = torch.xpu.empty_cache
synchronize = torch.xpu.synchronize
get_free_memory = get_xpu_free_memory
else:
empty_cache = noop
synchronize = noop
get_free_memory = get_cpu_free_memory
else:
SYSTEM = "cpu"

empty_cache = noop
synchronize = noop
get_free_memory = get_cpu_free_memory
logger.info(f"Detected system {SYSTEM}")
empty_cache = noop
synchronize = synchronize_hpu
get_free_memory = get_hpu_free_memory
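
The free-memory computation in the new get_hpu_free_memory above takes the pool limit minus the high-water mark, then reserves (1 - memory_fraction) of the limit as headroom. A sketch of the same arithmetic on a stand-in stats dict (values are made up), so it runs without an HPU or habana_frameworks installed:

def hpu_free_memory_sketch(mem_stats: dict, memory_fraction: float) -> int:
    # Same arithmetic as get_hpu_free_memory, on a plain dict instead of
    # habana_frameworks.torch.hpu.memory_stats(device_id).
    total_free_memory = mem_stats["Limit"] - mem_stats["MaxInUse"]
    return max(0, int(total_free_memory - (1 - memory_fraction) * mem_stats["Limit"]))

# Made-up numbers: 96 GB pool limit, 20 GB peak in-use, default 0.8 fraction.
stats = {"Limit": 96 * 2**30, "MaxInUse": 20 * 2**30}
print(hpu_free_memory_sketch(stats, memory_fraction=0.8))
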
@ -0,0 +1,22 @@
import importlib

from loguru import logger
from hf_kernels import load_kernel as hf_load_kernel

from text_generation_server.utils.log import log_once

def load_kernel(*, module: str, repo_id: str):
"""
Load a kernel. First try to load it as the given module (e.g. for
local development), falling back to a locked Hub kernel.
"""
try:
m = importlib.import_module(module)
log_once(logger.info, f"Using local module for `{module}`")
return m
except ModuleNotFoundError:
return hf_load_kernel(repo_id=repo_id)

__all__ = ["load_kernel"]
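
A hedged usage sketch for the new load_kernel helper above: a locally importable module wins, otherwise the kernel is pulled from the Hub via hf_kernels. The import path is assumed (the file's location is not visible in this excerpt), and the module/repo names are placeholders, not real kernels:

# Assumed import path for the helper defined above; adjust if it lives elsewhere.
from text_generation_server.utils.kernels import load_kernel

# "example_kernels" and "example-org/example-kernels" are placeholders.
kernels = load_kernel(
    module="example_kernels",               # tried first via importlib
    repo_id="example-org/example-kernels",  # hf_kernels fallback (locked Hub kernel)
)
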
@ -7,8 +7,6 @@ from typing import Dict, List, Optional, Union, Type
from safetensors import safe_open
from dataclasses import dataclass

from text_generation_server.utils.import_utils import SYSTEM

class WeightsLoader(ABC):
"""
@ -88,11 +86,8 @@ class UnquantizedWeight(Weight):
weight: torch.Tensor

def get_linear(self, bias: torch.Tensor):
from text_generation_server.layers.linear import FastLinear, FastLinearROCm
from text_generation_server.layers.linear import FastLinear

if SYSTEM == "rocm":
return FastLinearROCm(self.weight, bias)
else:
return FastLinear(self.weight, bias)

@ -197,7 +192,7 @@ class Weights:
slice_ = f.get_slice(tensor_name)
return slice_

def _has_tensor(self, tensor_name: str):
def has_tensor(self, tensor_name: str):
try:
self.get_filename(tensor_name)
except Exception:
@ -207,7 +202,9 @@ class Weights:
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()

def get_tensor(self, tensor_name: str, to_device=True, to_dtype=True):
def get_tensor(
self, tensor_name: str, to_device: bool = True, to_dtype: bool = True
) -> torch.Tensor:
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
@ -218,6 +215,7 @@ class Weights:
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
@ -253,7 +251,8 @@ class Weights:
# u4 which are disguised as int32. exl2 uses int16.
# FP8 uses torch.float8_e4m3fn.
if (
tensor.dtype not in (torch.float8_e4m3fn, torch.int16, torch.int32)
tensor.dtype
not in (torch.float8_e4m3fn, torch.int8, torch.int16, torch.int32)
and to_dtype
):
tensor = tensor.to(dtype=self.dtype)
@ -329,6 +328,7 @@ class Weights:
tensor.dtype
not in [
torch.float8_e4m3fn,
torch.int8,
torch.int16,
torch.int32,
torch.int64,