mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-10 20:04:52 +00:00
wip
This commit is contained in:
parent
259a230028
commit
2811ec9bff
@ -67,6 +67,16 @@ if FLASH_ATTENTION:
|
||||
__all__.append(FlashLlama)
|
||||
__all__.append(IDEFICSSharded)
|
||||
|
||||
MISTRAL = True
|
||||
try:
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Mistral model: {e}")
|
||||
MISTRAL = False
|
||||
|
||||
if MISTRAL:
|
||||
__all__.append(FlashMistral)
|
||||
|
||||
|
||||
def get_model(
|
||||
model_id: str,
|
||||
@ -237,7 +247,18 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "opt":
|
||||
if model_type == "mistral":
|
||||
if MISTRAL:
|
||||
return FlashMistral(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
raise NotImplementedError("Mistral model requires flash attention v2")
|
||||
|
||||
if model_type == "opt":
|
||||
return OPTSharded(
|
||||
model_id,
|
||||
revision,
|
||||
@ -246,7 +267,7 @@ def get_model(
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
elif model_type == "t5":
|
||||
if model_type == "t5":
|
||||
return T5Sharded(
|
||||
model_id,
|
||||
revision,
|
||||
@ -254,7 +275,7 @@ def get_model(
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == "idefics":
|
||||
if model_type == "idefics":
|
||||
if FLASH_ATTENTION:
|
||||
return IDEFICSSharded(
|
||||
model_id,
|
||||
|
135
server/text_generation_server/models/cache_manager.py
Normal file
135
server/text_generation_server/models/cache_manager.py
Normal file
@ -0,0 +1,135 @@
|
||||
import math
|
||||
import torch
|
||||
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
BLOCK_SIZE: int = 16
|
||||
# Will be set in warmup
|
||||
CACHE_MANAGER: Optional["CacheManager"] = None
|
||||
|
||||
|
||||
class CacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
repeat_slots: bool,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.block_size = BLOCK_SIZE
|
||||
self.num_blocks = num_blocks
|
||||
self.repeat_slots = repeat_slots
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
x = self.block_size // element_size
|
||||
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size // x, self.block_size, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size, self.block_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
||||
self.slots = torch.arange(
|
||||
0, num_blocks * self.block_size, dtype=torch.int32
|
||||
).view(num_blocks, self.block_size)
|
||||
|
||||
def allocate(
|
||||
self,
|
||||
needed_blocks_slots: List[Tuple[int, int]],
|
||||
blocks: int,
|
||||
max_blocks: int,
|
||||
device: torch.device,
|
||||
):
|
||||
# Get free blocks indices by finding values in mask that are not set to 0
|
||||
free_block_indices = self.free_block_mask.nonzero()
|
||||
assert (
|
||||
len(free_block_indices) >= blocks
|
||||
), f"Out of available cache blocks: asked {blocks}, only {len(free_block_indices)} free blocks"
|
||||
|
||||
# Slice by the number of required blocks
|
||||
block_indices = free_block_indices[:blocks]
|
||||
block_indices = block_indices.flatten()
|
||||
|
||||
# Padded block tables
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(needed_blocks_slots), max_blocks), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Allocate paged attention blocks
|
||||
cumulative_blocks = 0
|
||||
slots = []
|
||||
block_tables = []
|
||||
for i, (needed_blocks, needed_slots) in enumerate(needed_blocks_slots):
|
||||
# Get allocated blocks for this sequence
|
||||
allocated_blocks = block_indices[
|
||||
cumulative_blocks : cumulative_blocks + needed_blocks
|
||||
]
|
||||
# Get slots for the allocated blocks
|
||||
all_slots = self.slots[allocated_blocks].flatten()
|
||||
|
||||
# Repeat slots in the case of context sliding window
|
||||
if needed_slots > len(all_slots) and self.repeat_slots:
|
||||
repeats = math.ceil(needed_slots / len(all_slots))
|
||||
all_slots = all_slots.repeat(repeats)
|
||||
|
||||
allocated_slots = all_slots[:needed_slots]
|
||||
|
||||
slots.append(allocated_slots)
|
||||
block_tables.append(allocated_blocks.tolist())
|
||||
block_tables_tensor[i, :needed_blocks] = allocated_blocks
|
||||
cumulative_blocks += needed_blocks
|
||||
|
||||
block_tables = block_tables
|
||||
block_tables_tensor = block_tables_tensor.to(device)
|
||||
slots = torch.concat(slots).to(device)
|
||||
|
||||
# Allocate the required number of blocks by setting the mask to 0
|
||||
self.free_block_mask[block_indices] = 0
|
||||
|
||||
return block_tables, block_tables_tensor, slots
|
||||
|
||||
def free(self, block_indices: Optional[List[int]]):
|
||||
if block_indices is not None and block_indices:
|
||||
# Reset mask
|
||||
self.free_block_mask[block_indices] = 1
|
||||
|
||||
|
||||
def set_cache_manager(
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
repeat_slots: bool,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> CacheManager:
|
||||
global CACHE_MANAGER
|
||||
if CACHE_MANAGER is not None:
|
||||
del CACHE_MANAGER
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
CACHE_MANAGER = CacheManager(
|
||||
num_blocks, num_layers, num_heads, head_size, repeat_slots, dtype, device
|
||||
)
|
||||
return CACHE_MANAGER
|
||||
|
||||
|
||||
def get_cache_manager() -> CacheManager:
|
||||
global CACHE_MANAGER
|
||||
if CACHE_MANAGER is None:
|
||||
raise RuntimeError("cache manager was not initialized")
|
||||
|
||||
return CACHE_MANAGER
|
@ -0,0 +1,532 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
# Flash attention imports
|
||||
import dropout_layer_norm
|
||||
|
||||
# vllm imports
|
||||
import vllm_cache_ops
|
||||
import vllm_attention_ops
|
||||
|
||||
from text_generation_server.utils.flash_attn import attention, HAS_FLASH_ATTN_V2
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelRowLinear,
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
PositionRotaryEmbedding,
|
||||
TensorParallelHead,
|
||||
get_linear,
|
||||
)
|
||||
|
||||
if not HAS_FLASH_ATTN_V2:
|
||||
raise ImportError("Mistral model requires flash attn v2")
|
||||
|
||||
|
||||
class MistralConfig(PretrainedConfig):
|
||||
model_type = "mistral"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=32000,
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096 * 32,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
pretraining_tp=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
sliding_window=4096,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.pretraining_tp = pretraining_tp
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class MistralRMSNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps=1e-6):
|
||||
"""
|
||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
weight = weights.get_tensor(f"{prefix}.weight")
|
||||
self.weight = nn.Parameter(weight)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states, residual=None):
|
||||
if hidden_states.shape[-1] > 8192:
|
||||
if residual is not None:
|
||||
hidden_states += residual
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states, residual
|
||||
else:
|
||||
# faster post attention rms norm
|
||||
normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
|
||||
hidden_states,
|
||||
residual,
|
||||
self.weight,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
0.0,
|
||||
self.variance_epsilon,
|
||||
1.0,
|
||||
0,
|
||||
None,
|
||||
False,
|
||||
True, # Activate RMSNorm
|
||||
)
|
||||
if res is None:
|
||||
res = hidden_states
|
||||
|
||||
return normed_hidden_states, res
|
||||
|
||||
|
||||
def load_attention(config, prefix, weights):
|
||||
if config.num_attention_heads != config.num_key_value_heads:
|
||||
return _load_gqa(config, prefix, weights)
|
||||
else:
|
||||
return TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
|
||||
def _load_gqa(config, prefix: str, weights):
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
assert config.num_attention_heads % weights.process_group.size() == 0
|
||||
|
||||
weight = weights.get_multi_weights_col(
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
quantize=config.quantize,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if config.quantize not in ["gptq", "awq"]:
|
||||
weight = weight.to(dtype=weights.dtype).to(device=weights.device)
|
||||
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
num_heads = config.num_attention_heads // weights.process_group.size()
|
||||
num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
|
||||
assert list(weight.shape) == [
|
||||
(num_heads + 2 * num_key_value_heads) * head_size,
|
||||
config.hidden_size,
|
||||
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"
|
||||
|
||||
return TensorParallelColumnLinear(
|
||||
get_linear(weight, bias=None, quantize=config.quantize)
|
||||
)
|
||||
|
||||
|
||||
class MistralAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_past = (
|
||||
config.sliding_window if config.sliding_window is not None else 0
|
||||
)
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.num_heads
|
||||
|
||||
self.rotary_emb = PositionRotaryEmbedding.static(
|
||||
config=config,
|
||||
dim=self.head_size,
|
||||
base=config.rope_theta,
|
||||
device=weights.device,
|
||||
)
|
||||
|
||||
self.softmax_scale = self.head_size**-0.5
|
||||
|
||||
if self.num_heads % weights.process_group.size() != 0:
|
||||
raise ValueError(
|
||||
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
|
||||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = load_attention(config, prefix, weights)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
).repeat_interleave(self.num_groups)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
2 * self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
|
||||
|
||||
self.rotary_emb(query, cos, sin)
|
||||
self.rotary_emb(torch.select(kv, dim=1, index=0), cos, sin)
|
||||
|
||||
if prefill_cache_indices is not None:
|
||||
kv_to_cache = kv[prefill_cache_indices]
|
||||
else:
|
||||
kv_to_cache = kv
|
||||
|
||||
vllm_cache_ops.reshape_and_cache(
|
||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||
)
|
||||
|
||||
# output tensor
|
||||
attn_output = torch.empty_like(query)
|
||||
|
||||
# Prefill
|
||||
if cu_seqlen_prefill is not None:
|
||||
# flash attention
|
||||
attention(
|
||||
query,
|
||||
torch.select(kv, dim=1, index=0),
|
||||
torch.select(kv, dim=1, index=1),
|
||||
attn_output,
|
||||
cu_seqlen_prefill,
|
||||
max_s,
|
||||
self.softmax_scale,
|
||||
max_past=self.max_past,
|
||||
)
|
||||
# Decode
|
||||
else:
|
||||
# kv_cache[1] => [num_blocks, num_heads, head_size, block_size]
|
||||
block_size = kv_cache[1].shape[3]
|
||||
vllm_attention_ops.single_query_cached_kv_attention(
|
||||
attn_output,
|
||||
query,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
self.kv_head_mapping,
|
||||
self.softmax_scale,
|
||||
block_tables,
|
||||
input_lengths,
|
||||
block_size,
|
||||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
class MistralMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate="tanh"
|
||||
if act in ["gelu_fast", "gelu_pytorch_tanh"]
|
||||
else "none",
|
||||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
|
||||
|
||||
class MistralLayer(nn.Module):
|
||||
def __init__(self, layer_id, config, weights):
|
||||
super().__init__()
|
||||
prefix = f"model.layers.{layer_id}"
|
||||
self.self_attn = MistralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = MistralRMSNorm(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = MistralRMSNorm(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
attn_output = self.self_attn(
|
||||
normed_hidden_states,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
normed_attn_res_output, attn_res = self.post_attention_layernorm(
|
||||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
||||
class MistralModel(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
process_group = weights.process_group
|
||||
self.tp_rank = process_group.rank()
|
||||
self.tp_world_size = process_group.size()
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix="model.embed_tokens", weights=weights
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MistralLayer(
|
||||
layer_id,
|
||||
config,
|
||||
weights,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = MistralRMSNorm(
|
||||
prefix="model.norm", weights=weights, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Get rotary cos and sin for this forward
|
||||
# Avoid to index in each layer
|
||||
cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
|
||||
position_ids, max_s, hidden_states.dtype
|
||||
)
|
||||
|
||||
residual = None
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states, residual = layer(
|
||||
hidden_states,
|
||||
residual,
|
||||
cos,
|
||||
sin,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache[i],
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlashMistralForCausalLM(torch.nn.Module):
|
||||
def __init__(self, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.model = MistralModel(config, weights)
|
||||
self.lm_head = TensorParallelHead.load(
|
||||
config,
|
||||
prefix="lm_head",
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
if self.max_past is None:
|
||||
raise ValueError("max_past cannot be None")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if prefill_cache_indices is not None:
|
||||
# Slots also need to be sliced as it has the same size as the whole kv tensor
|
||||
slots = slots[prefill_cache_indices]
|
||||
else:
|
||||
# Clamp in decode mode as paged attention requires clamped values whereas the flash attention
|
||||
# kernel requires the true values
|
||||
max_s = min(self.max_past, max_s)
|
||||
input_lengths = torch.clamp(input_lengths, max=self.max_past)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.lm_head(hidden_states)
|
||||
return logits
|
@ -19,99 +19,17 @@ from text_generation_server.models.types import (
|
||||
GeneratedText,
|
||||
TopTokens,
|
||||
)
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
set_cache_manager,
|
||||
BLOCK_SIZE,
|
||||
)
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
|
||||
from text_generation_server.utils.dist import MEMORY_FRACTION
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
# Will be set in warmup
|
||||
CACHE_MANAGER: Optional["CacheManager"] = None
|
||||
|
||||
|
||||
class CacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
num_blocks: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
):
|
||||
self.block_size = BLOCK_SIZE
|
||||
self.num_blocks = num_blocks
|
||||
|
||||
element_size = torch.tensor([], dtype=dtype).element_size()
|
||||
x = self.block_size // element_size
|
||||
|
||||
self.kv_cache = [
|
||||
(
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size // x, self.block_size, x),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
torch.empty(
|
||||
(num_blocks, num_heads, head_size, self.block_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
self.free_block_mask = torch.ones(num_blocks, dtype=torch.int32, device="cpu")
|
||||
self.slots = torch.arange(
|
||||
0, num_blocks * self.block_size, dtype=torch.int32
|
||||
).view(num_blocks, self.block_size)
|
||||
|
||||
def allocate(self, batch: "FlashCausalLMBatch"):
|
||||
# Get free blocks indices by finding values in mask that are not set to 0
|
||||
free_block_indices = self.free_block_mask.nonzero()
|
||||
assert (
|
||||
len(free_block_indices) >= batch.blocks
|
||||
), f"Out of available cache blocks: asked {batch.blocks}, only {len(free_block_indices)} free blocks"
|
||||
|
||||
# Slice by the number of required blocks
|
||||
block_indices = free_block_indices[: batch.blocks]
|
||||
block_indices = block_indices.flatten()
|
||||
|
||||
# Padded block tables
|
||||
block_tables_tensor = torch.zeros(
|
||||
(len(batch), batch.max_blocks), dtype=torch.int32
|
||||
)
|
||||
|
||||
# Allocate paged attention blocks
|
||||
cumulative_blocks = 0
|
||||
slots = []
|
||||
block_tables = []
|
||||
for i, (needed_blocks, needed_slots) in enumerate(batch.needed_blocks_slots):
|
||||
# Get allocated blocks for this sequence
|
||||
allocated_blocks = block_indices[
|
||||
cumulative_blocks : cumulative_blocks + needed_blocks
|
||||
]
|
||||
# Get slots for the allocated blocks
|
||||
allocated_slots = self.slots[allocated_blocks].flatten()[:needed_slots]
|
||||
|
||||
slots.append(allocated_slots)
|
||||
block_tables.append(allocated_blocks.tolist())
|
||||
block_tables_tensor[i, :needed_blocks] = allocated_blocks
|
||||
cumulative_blocks += needed_blocks
|
||||
|
||||
batch.needed_blocks_slots = None
|
||||
batch.block_tables = block_tables
|
||||
batch.block_tables_tensor = block_tables_tensor.to(batch.input_ids.device)
|
||||
batch.slots = torch.concat(slots).to(batch.input_ids.device)
|
||||
|
||||
# Allocate the required number of blocks by setting the mask to 0
|
||||
self.free_block_mask[block_indices] = 0
|
||||
|
||||
def free(self, block_indices: Optional[List[int]]):
|
||||
if block_indices is not None and block_indices:
|
||||
# Reset mask
|
||||
self.free_block_mask[block_indices] = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashCausalLMBatch(Batch):
|
||||
@ -481,7 +399,6 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
max_blocks = max(max_blocks, len(request_block_table))
|
||||
|
||||
global CACHE_MANAGER
|
||||
block_indices_to_free = []
|
||||
# Iterate on all requests
|
||||
for i, r in enumerate(self.requests):
|
||||
@ -489,7 +406,7 @@ class FlashCausalLMBatch(Batch):
|
||||
if r.id not in requests_idx_mapping.keys():
|
||||
block_indices_to_free.extend(self.block_tables[i])
|
||||
# Free blocks
|
||||
CACHE_MANAGER.free(block_indices_to_free)
|
||||
get_cache_manager().free(block_indices_to_free)
|
||||
# Needed to avoid dropping blocks when the batches will go out of scope
|
||||
self.block_tables = None
|
||||
|
||||
@ -508,7 +425,7 @@ class FlashCausalLMBatch(Batch):
|
||||
# Move to GPU now that we have the whole tensor
|
||||
slot_indices = slot_indices.to(device)
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
return type(self)(
|
||||
batch_id=self.batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
@ -665,7 +582,7 @@ class FlashCausalLMBatch(Batch):
|
||||
b.block_tables = None
|
||||
del b
|
||||
|
||||
return FlashCausalLMBatch(
|
||||
return cls(
|
||||
batch_id=batches[0].batch_id,
|
||||
requests=requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
@ -698,9 +615,10 @@ class FlashCausalLMBatch(Batch):
|
||||
|
||||
def __del__(self):
|
||||
if self.block_tables is not None and self.block_tables:
|
||||
global CACHE_MANAGER
|
||||
# Free blocks
|
||||
CACHE_MANAGER.free(list(itertools.chain.from_iterable(self.block_tables)))
|
||||
get_cache_manager().free(
|
||||
list(itertools.chain.from_iterable(self.block_tables))
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.requests)
|
||||
@ -718,10 +636,12 @@ class FlashCausalLM(Model):
|
||||
device: torch.device,
|
||||
rank: int = 0,
|
||||
world_size: int = 1,
|
||||
repeat_slots: bool = False,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_size = head_size
|
||||
self.repeat_slots = repeat_slots
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
model=model,
|
||||
@ -738,15 +658,14 @@ class FlashCausalLM(Model):
|
||||
return FlashCausalLMBatch
|
||||
|
||||
def warmup(self, batch: FlashCausalLMBatch):
|
||||
global CACHE_MANAGER
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
try:
|
||||
CACHE_MANAGER = CacheManager(
|
||||
cache_manager = set_cache_manager(
|
||||
batch.blocks,
|
||||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.repeat_slots,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
@ -775,48 +694,36 @@ class FlashCausalLM(Model):
|
||||
num_blocks = (
|
||||
int(free_memory // total_cache_size)
|
||||
# Add batch.blocks as we allocated it above, so it is included in the peak memory.
|
||||
+ CACHE_MANAGER.num_blocks
|
||||
+ cache_manager.num_blocks
|
||||
)
|
||||
|
||||
del CACHE_MANAGER
|
||||
del batch
|
||||
torch.cuda.empty_cache()
|
||||
del cache_manager
|
||||
|
||||
CACHE_MANAGER = CacheManager(
|
||||
set_cache_manager(
|
||||
num_blocks,
|
||||
self.num_layers,
|
||||
self.num_kv_heads,
|
||||
self.head_size,
|
||||
self.repeat_slots,
|
||||
self.dtype,
|
||||
self.device,
|
||||
)
|
||||
|
||||
return int(num_blocks * BLOCK_SIZE)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
global CACHE_MANAGER
|
||||
|
||||
def forward(self, batch: FlashCausalLMBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Model Forward
|
||||
return self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=CACHE_MANAGER.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
lm_head_indices=lm_head_indices,
|
||||
input_ids=batch.input_ids,
|
||||
position_ids=batch.position_ids,
|
||||
cu_seqlen_prefill=batch.cu_seqlen_prefill,
|
||||
kv_cache=get_cache_manager().kv_cache,
|
||||
block_tables=batch.block_tables_tensor,
|
||||
slots=batch.slots[batch.slot_indices],
|
||||
input_lengths=batch.input_lengths_tensor,
|
||||
max_s=batch.max_seqlen,
|
||||
lm_head_indices=batch.prefill_head_indices,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
@ -828,19 +735,19 @@ class FlashCausalLM(Model):
|
||||
|
||||
if batch.needed_blocks_slots:
|
||||
# Allocate blocks to this batch
|
||||
CACHE_MANAGER.allocate(batch)
|
||||
block_tables, block_tables_tensor, slots = get_cache_manager().allocate(
|
||||
batch.needed_blocks_slots,
|
||||
batch.blocks,
|
||||
batch.max_blocks,
|
||||
batch.input_ids.device,
|
||||
)
|
||||
batch.needed_blocks_slots = None
|
||||
batch.block_tables = block_tables
|
||||
batch.block_tables_tensor = block_tables_tensor
|
||||
batch.slots = slots
|
||||
|
||||
try:
|
||||
out = self.forward(
|
||||
batch.input_ids,
|
||||
batch.position_ids,
|
||||
batch.cu_seqlen_prefill,
|
||||
batch.block_tables_tensor,
|
||||
batch.slots[batch.slot_indices],
|
||||
batch.input_lengths_tensor,
|
||||
batch.max_seqlen,
|
||||
batch.prefill_head_indices,
|
||||
)
|
||||
out = self.forward(batch)
|
||||
except Exception as e:
|
||||
del batch
|
||||
raise e
|
||||
|
357
server/text_generation_server/models/flash_mistral.py
Normal file
357
server/text_generation_server/models/flash_mistral.py
Normal file
@ -0,0 +1,357 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import numpy as np
|
||||
|
||||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from transformers.models.llama import LlamaTokenizerFast
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from text_generation_server.pb import generate_pb2
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch, BLOCK_SIZE
|
||||
from text_generation_server.models.cache_manager import (
|
||||
get_cache_manager,
|
||||
set_cache_manager,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
MistralConfig,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
HeterogeneousNextTokenChooser,
|
||||
StoppingCriteria,
|
||||
)
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
# Will be set in init
|
||||
SLIDING_WINDOW: Optional[int] = None
|
||||
SLIDING_WINDOW_BLOCKS: Optional[int] = None
|
||||
|
||||
|
||||
# Adds windowing logic to FlashCausalLMBatch
|
||||
@dataclass
|
||||
class FlashMistralBatch(FlashCausalLMBatch):
|
||||
# Prefill cache indices is used to slice into the kv tensor before caching it into the paged attention buffers
|
||||
# as we only keep SLIDING_WINDOW values instead of the whole tensor
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls,
|
||||
pb: generate_pb2.Batch,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> "FlashCausalLMBatch":
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
|
||||
batch_inputs = []
|
||||
max_truncation = 0
|
||||
for r in pb.requests:
|
||||
batch_inputs.append(r.inputs)
|
||||
max_truncation = max(max_truncation, r.truncate)
|
||||
|
||||
batch_tokenized_inputs = tokenizer(
|
||||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
|
||||
position_ids = []
|
||||
cu_seqlen_prefill = [0]
|
||||
needed_blocks_slots = []
|
||||
start_slots = []
|
||||
slot_indices = []
|
||||
prefill_cache_indices = []
|
||||
|
||||
input_lengths = []
|
||||
prefix_offsets = []
|
||||
read_offsets = []
|
||||
all_input_ids = []
|
||||
requests_idx_mapping = {}
|
||||
|
||||
all_prefill_logprobs = True
|
||||
no_prefill_logprobs = True
|
||||
prefill_head_indices = []
|
||||
prefill_next_token_indices = []
|
||||
prefill_cu_outlens = [0]
|
||||
|
||||
next_token_chooser_parameters = []
|
||||
stopping_criterias = []
|
||||
top_n_tokens = []
|
||||
|
||||
# Cumulative length
|
||||
cumulative_length = 0
|
||||
cumulative_max_length = 0
|
||||
prefill_out_cumulative_length = 0
|
||||
|
||||
blocks = 0
|
||||
max_seqlen = 0
|
||||
max_length = 0
|
||||
max_blocks = 0
|
||||
|
||||
# Parse batch
|
||||
for i, (r, tokenized_input) in enumerate(
|
||||
zip(pb.requests, batch_tokenized_inputs)
|
||||
):
|
||||
# request id -> idx in list mapping
|
||||
requests_idx_mapping[r.id] = i
|
||||
|
||||
tokenized_input = tokenized_input[-r.truncate :]
|
||||
|
||||
input_length = len(tokenized_input)
|
||||
input_lengths.append(input_length)
|
||||
|
||||
prefix_offsets.append(input_length - 5)
|
||||
read_offsets.append(input_length)
|
||||
|
||||
all_input_ids.append(tokenized_input)
|
||||
|
||||
# Position ids
|
||||
request_position_ids = torch.arange(0, input_length, dtype=torch.int32)
|
||||
position_ids.append(request_position_ids)
|
||||
|
||||
# Add cumulative lengths of all previous inputs
|
||||
cu_seqlen_prefill.append(cumulative_length + input_length)
|
||||
|
||||
next_token_chooser_parameters.append(r.parameters)
|
||||
|
||||
stopping_criteria = StoppingCriteria.from_pb(
|
||||
r.stopping_parameters, tokenizer
|
||||
)
|
||||
max_new_tokens = stopping_criteria.max_new_tokens
|
||||
stopping_criterias.append(stopping_criteria)
|
||||
top_n_tokens.append(r.top_n_tokens)
|
||||
|
||||
# Paged attention
|
||||
# Remove one as the first token des not have a past
|
||||
total_tokens = input_length + max_new_tokens - 1
|
||||
|
||||
# Needed blocks can not go over SLIDING_WINDOW_BLOCKS
|
||||
needed_blocks = min(
|
||||
math.ceil(total_tokens / BLOCK_SIZE), SLIDING_WINDOW_BLOCKS
|
||||
)
|
||||
blocks += needed_blocks
|
||||
|
||||
needed_blocks_slots.append((needed_blocks, total_tokens))
|
||||
start_slots.append(cumulative_max_length)
|
||||
|
||||
request_slot_indices = torch.arange(
|
||||
cumulative_max_length,
|
||||
cumulative_max_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
slot_indices.append(request_slot_indices)
|
||||
|
||||
# Create tensor to slice into the kv tensor in prefill
|
||||
request_prefill_cache_indices = torch.arange(
|
||||
cumulative_length + max(0, input_length - SLIDING_WINDOW),
|
||||
cumulative_length + input_length,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
prefill_cache_indices.append(request_prefill_cache_indices)
|
||||
|
||||
all_prefill_logprobs = all_prefill_logprobs and r.prefill_logprobs
|
||||
no_prefill_logprobs = no_prefill_logprobs and not r.prefill_logprobs
|
||||
|
||||
if r.prefill_logprobs:
|
||||
prefill_head_indices.append(request_position_ids + cumulative_length)
|
||||
prefill_next_token_indices.append(
|
||||
prefill_out_cumulative_length + input_length - 1
|
||||
)
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + input_length)
|
||||
prefill_out_cumulative_length += input_length
|
||||
else:
|
||||
prefill_head_indices.append(
|
||||
torch.tensor(
|
||||
[cumulative_length + input_length - 1], dtype=torch.int32
|
||||
)
|
||||
)
|
||||
prefill_next_token_indices.append(prefill_out_cumulative_length)
|
||||
prefill_cu_outlens.append(prefill_out_cumulative_length + 1)
|
||||
prefill_out_cumulative_length += 1
|
||||
|
||||
# Update
|
||||
cumulative_length += input_length
|
||||
cumulative_max_length += total_tokens
|
||||
max_seqlen = max(max_seqlen, input_length)
|
||||
max_blocks = max(max_blocks, needed_blocks)
|
||||
max_length = max(max_length, input_length + max_new_tokens)
|
||||
|
||||
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
|
||||
next_token_chooser_parameters, dtype, device
|
||||
)
|
||||
start_slots = torch.tensor(start_slots, dtype=torch.int64)
|
||||
|
||||
# Padded all_input_ids_tensor
|
||||
all_input_ids_tensor = np.zeros(
|
||||
(len(all_input_ids), max_length), dtype=np.int64
|
||||
)
|
||||
for i, input_ids in enumerate(all_input_ids):
|
||||
all_input_ids_tensor[i, : len(input_ids)] = input_ids
|
||||
|
||||
# Create tensors on device
|
||||
all_input_ids_tensor = torch.tensor(
|
||||
all_input_ids_tensor, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
if len(pb.requests) > 1:
|
||||
input_ids = np.concatenate(all_input_ids, dtype=np.int64)
|
||||
position_ids = torch.cat(position_ids)
|
||||
slot_indices = torch.cat(slot_indices)
|
||||
prefill_cache_indices = torch.cat(prefill_cache_indices)
|
||||
else:
|
||||
input_ids = all_input_ids[0]
|
||||
position_ids = position_ids[0]
|
||||
slot_indices = slot_indices[0]
|
||||
prefill_cache_indices = prefill_cache_indices[0]
|
||||
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
cu_seqlen_prefill, device=device, dtype=torch.int32
|
||||
)
|
||||
|
||||
position_ids = position_ids.to(device)
|
||||
slot_indices = slot_indices.to(device)
|
||||
prefill_cache_indices = prefill_cache_indices.to(device)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, device=device)
|
||||
input_lengths_tensor = torch.tensor(
|
||||
input_lengths, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
if all_prefill_logprobs:
|
||||
prefill_head_indices = None
|
||||
prefill_next_token_indices = cu_seqlen_prefill[1:] - 1
|
||||
elif no_prefill_logprobs:
|
||||
prefill_head_indices = cu_seqlen_prefill[1:] - 1
|
||||
prefill_next_token_indices = None
|
||||
else:
|
||||
prefill_head_indices = torch.tensor(
|
||||
torch.cat(prefill_head_indices), dtype=torch.int64, device=device
|
||||
)
|
||||
prefill_next_token_indices = torch.tensor(
|
||||
prefill_next_token_indices, dtype=torch.int64, device=device
|
||||
)
|
||||
top_n_tokens_tensor = torch.tensor(
|
||||
top_n_tokens, device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
return cls(
|
||||
batch_id=pb.id,
|
||||
requests=pb.requests,
|
||||
requests_idx_mapping=requests_idx_mapping,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
start_slots=start_slots,
|
||||
slot_indices=slot_indices,
|
||||
needed_blocks_slots=needed_blocks_slots,
|
||||
block_tables=None,
|
||||
block_tables_tensor=None,
|
||||
slots=None,
|
||||
max_seqlen=max_seqlen,
|
||||
prefill_head_indices=prefill_head_indices,
|
||||
prefill_next_token_indices=prefill_next_token_indices,
|
||||
prefill_cu_outlens=prefill_cu_outlens,
|
||||
input_lengths=input_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
prefix_offsets=prefix_offsets,
|
||||
read_offsets=read_offsets,
|
||||
all_input_ids=all_input_ids,
|
||||
all_input_ids_tensor=all_input_ids_tensor,
|
||||
next_token_chooser=next_token_chooser,
|
||||
stopping_criterias=stopping_criterias,
|
||||
top_n_tokens=top_n_tokens,
|
||||
top_n_tokens_tensor=top_n_tokens_tensor,
|
||||
blocks=blocks,
|
||||
max_blocks=max_blocks,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
)
|
||||
|
||||
|
||||
class FlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
global SLIDING_WINDOW
|
||||
global SLIDING_WINDOW_BLOCKS
|
||||
|
||||
self.process_group, rank, world_size = initialize_torch_distributed()
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = torch.float16 if dtype is None else dtype
|
||||
else:
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
config = MistralConfig.from_pretrained(
|
||||
model_id, revision=revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
config.quantize = quantize
|
||||
|
||||
# Set context windows
|
||||
SLIDING_WINDOW = config.sliding_window
|
||||
SLIDING_WINDOW_BLOCKS = math.ceil(config.sliding_window / BLOCK_SIZE)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
|
||||
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
|
||||
weights = Weights(filenames, device, dtype, process_group=self.process_group)
|
||||
if config.quantize in ["gptq", "awq"]:
|
||||
weights._set_gptq_params(model_id)
|
||||
|
||||
model = FlashMistralForCausalLM(config, weights)
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashMistral, self).__init__(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_layers=len(model.model.layers),
|
||||
num_kv_heads=model.model.num_key_value_heads,
|
||||
head_size=model.model.head_size,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
repeat_slots=True,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[FlashMistralBatch]:
|
||||
return FlashMistralBatch
|
||||
|
||||
def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Model Forward
|
||||
logits = self.model.forward(
|
||||
input_ids=batch.input_ids,
|
||||
position_ids=batch.position_ids,
|
||||
cu_seqlen_prefill=batch.cu_seqlen_prefill,
|
||||
kv_cache=get_cache_manager().kv_cache,
|
||||
block_tables=batch.block_tables_tensor,
|
||||
slots=batch.slots[batch.slot_indices],
|
||||
input_lengths=batch.input_lengths_tensor,
|
||||
max_s=batch.max_seqlen,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=batch.prefill_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits
|
@ -57,6 +57,7 @@ def attention(
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
softmax_scale,
|
||||
max_past=0,
|
||||
):
|
||||
if HAS_FLASH_ATTN_V2:
|
||||
return flash_attn_2_cuda.varlen_fwd(
|
||||
@ -72,11 +73,15 @@ def attention(
|
||||
softmax_scale,
|
||||
False,
|
||||
True,
|
||||
max_past,
|
||||
False,
|
||||
None,
|
||||
)
|
||||
|
||||
if HAS_FLASH_ATTN:
|
||||
if max_past != 0:
|
||||
raise NotImplementedError("max_past is only available with flash attn v2")
|
||||
|
||||
# Flash attention v1 requires q, k and v to have the same number of heads
|
||||
if k.shape[1] != q.shape[1]:
|
||||
# MQA expand
|
||||
|
Loading…
Reference in New Issue
Block a user