mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00

commit ade0f44aca (parent 9f5c9a5e22): add transformers_flash
server/text_generation_server/models/__init__.py
@@ -20,6 +20,7 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+from text_generation_server.models.transformers_flash_causal_lm import TransformersFlashCausalLM
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -28,7 +29,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -366,12 +367,38 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
+    global USE_CUSTOM_MODELING

     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
     model_type = config_dict.get("model_type", None)

+    transformers_causal_lm_class = CausalLM
+    if (
+        not USE_CUSTOM_MODELING
+        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    ):
+        logger.info(
+            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
+        )
+        transformers_model_class = getattr(
+            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+        )
+
+        if (
+            transformers_model_class._supports_flash_attn_2
+            and transformers_model_class._supports_cache_class
+        ):
+            logger.info(
+                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
+            )
+            transformers_causal_lm_class = TransformersFlashCausalLM
+        else:
+            logger.info(
+                f"Transformers' {model_type} implementation does not support custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)."
+            )
+
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
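For reference, a minimal sketch of the capability check performed in the hunk above, extracted into a standalone helper (the helper name is hypothetical; `modeling_auto` is assumed to be `transformers.models.auto.modeling_auto`, and `_supports_flash_attn_2` / `_supports_cache_class` are the private Transformers class attributes referenced in the diff):

import transformers
from transformers.models.auto import modeling_auto


def supports_flash_path(model_type: str) -> bool:
    # Mirrors the selection in get_model above: the ragged/flash path is taken
    # only when the Transformers class advertises FlashAttention-2 support and
    # the new cache-class API; otherwise padded CausalLM tensors are used.
    if model_type not in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return False
    model_class = getattr(
        transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
    )
    return bool(
        getattr(model_class, "_supports_flash_attn_2", False)
        and getattr(model_class, "_supports_cache_class", False)
    )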
server/text_generation_server/models/globals.py
@@ -67,3 +67,7 @@ def set_adapter_to_index(adapter_to_index: Dict[str, int]):
 def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX
+
+
+USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
+USE_CUSTOM_MODELING = USE_CUSTOM_MODELING == "true" or USE_CUSTOM_MODELING == "1"
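The flag is evaluated once when `text_generation_server.models.globals` is imported, so custom modeling is toggled through the environment. A minimal sketch, assuming the variable is set before the server package is imported:

import os

# Any value other than "true" or "1" disables TGI's custom modeling code, so
# get_model falls back to the Transformers implementations selected above.
os.environ["USE_CUSTOM_MODELING"] = "false"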
server/text_generation_server/models/transformers_flash_causal_lm.py (new file)
@@ -0,0 +1,309 @@
import math
import sys
from typing import Optional, Tuple, Dict, Any

import torch
from opentelemetry import trace
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM

from text_generation_server.models.flash_causal_lm import (
    FlashCausalLMBatch,
    FlashCausalLM,
)
from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,
    get_free_memory,
)
from text_generation_server.adapters import AdapterBatchData
from text_generation_server.layers.attention import paged_attention, attention, Seqlen
from text_generation_server.layers.attention.kv_cache import KVScales
from text_generation_server.models.globals import ATTENTION
from text_generation_server.models.metadata_kernels import block_tables_to_ragged


tracer = trace.get_tracer(__name__)


def patch_everywhere(
    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
):
    """
    Finds all occurrences of `attribute_name` in the loaded modules and patches them with `patch`.

    Args:
        attribute_name (`str`):
            The name of the attribute to patch.
        patch (`Any`):
            The patch for the attribute.
        module_name_prefix (`Optional[str]`, defaults to `None`):
            If set, only module names starting with this prefix will be considered for patching.
    """
    # sys.modules may be updated while being iterated over, hence the list copy.
    for name in list(sys.modules):
        module = sys.modules[name]
        if module_name_prefix is not None and not name.startswith(module_name_prefix):
            continue
        if hasattr(module, attribute_name):
            setattr(module, attribute_name, patch)
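# The function below is swapped in for Transformers' `_flash_attention_forward`
# during TransformersFlashCausalLM.warmup (via patch_everywhere above), so every
# already-imported modeling module routes attention through TGI's paged KV cache
# and flash/paged attention kernels instead of the stock Transformers path.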
def _flash_attention_forward_patched(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    softmax_scale: Optional[float] = None,
    sliding_window: int = -1,
    softcap: Optional[float] = None,
    **kwargs,
):
    kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]]
    # This means no scale
    kv_scales = KVScales(
        torch.tensor(1.0, device=key_states.device),
        torch.tensor(1.0, device=key_states.device),
    )

    # Correctly reshape the states: from [batch, seq, heads, head_dim]
    # to ragged [tokens, heads, head_dim]
    _, _, num_heads, head_dim = query_states.size()
    _, _, num_kv_heads, _ = key_states.size()
    query_states = query_states.view(-1, num_heads, head_dim)
    key_states = key_states.view(-1, num_kv_heads, head_dim)
    value_states = value_states.view(-1, num_kv_heads, head_dim)

    # Take care of updating the cache in-place
    kv_cache.store(
        key=key_states,
        value=value_states,
        slots=kwargs["slots"],
        kv_scales=kv_scales,
    )

    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale

    if kwargs["cu_seqlen_prefill"] is not None:
        # Prefill: variable-length attention over the concatenated sequences
        attn_output = attention(
            query=query_states,
            key=key_states,
            value=value_states,
            kv_cache=kv_cache,
            kv_scales=kv_scales,
            seqlen=kwargs["seqlen"],
            block_tables=kwargs["block_tables"],
            softmax_scale=softmax_scale,
            window_size_left=sliding_window,
            softcap=softcap,
        )
    else:
        # Decode: paged attention against the block-allocated KV cache
        attn_output = paged_attention(
            query_states,
            kv_cache,
            kwargs["kv_head_mapping"],
            softmax_scale,
            kwargs["block_tables"],
            kwargs["seqlen"],
            kwargs["max_s"],
            kv_scales=kv_scales,
            softcap=softcap,
        )

    # Merge the heads back into the hidden dimension
    attn_output = attn_output.view(attn_output.shape[0], -1)

    return attn_output
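# TransformersFlashCausalLM loads weights through AutoModelForCausalLM but otherwise
# relies on FlashCausalLM for batching, KV-cache allocation, warmup and CUDA-graph
# handling; only __init__, warmup and forward are overridden below.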
class TransformersFlashCausalLM(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        if speculator:
            raise RuntimeError("Speculator decoding is not enabled for AutoModel")

        device_count = 0
        if torch.cuda.is_available():
            device = torch.device("cuda")
            device_count = torch.cuda.device_count()
            dtype = torch.float16 if dtype is None else dtype
        elif hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device("xpu")
            device_count = torch.xpu.device_count()
            dtype = torch.float16 if dtype is None else dtype
        else:
            if quantize:
                raise ValueError("quantization is not available on CPU")

            device = torch.device("cpu")
            dtype = torch.float32 if dtype is None else dtype

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=dtype,
            device_map=("auto" if device_count > 1 else None),
            load_in_8bit=quantize == "bitsandbytes",
            trust_remote_code=trust_remote_code,
        )
        if device_count == 1 and quantize != "bitsandbytes":
            model = model.to(device)

        if tokenizer.pad_token_id is None:
            if model.config.pad_token_id is not None:
                tokenizer.pad_token_id = model.config.pad_token_id
            elif model.config.eos_token_id is not None and isinstance(
                model.config.eos_token_id, int
            ):
                tokenizer.pad_token_id = model.config.eos_token_id
            elif tokenizer.eos_token_id is not None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        self.num_layers = len(model.model.layers)
        self.num_kv_heads = model.config.num_key_value_heads
        self.head_size = model.config.hidden_size // model.config.num_attention_heads

        # Skip FlashCausalLM.__init__: weights are already loaded via AutoModelForCausalLM.
        super(FlashCausalLM, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )
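    # A note on ordering (assumption based on the call sites above): warmup installs
    # the patched attention forward before FlashCausalLM's own warmup runs, and by
    # that point the model's Transformers modeling modules are already in
    # sys.modules, so patch_everywhere can reach them.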
    def warmup(self, batch: FlashCausalLMBatch):
        patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
        super().warmup(batch)

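    # forward mirrors FlashCausalLM.forward: the TGI-specific keyword arguments
    # (cu_seqlen_prefill, kv_cache, block_tables, slots, seqlen, max_s, ...) are
    # passed to self.model.forward and are expected to flow through Transformers'
    # kwargs plumbing into the patched _flash_attention_forward above.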
    def forward(
        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # NOTE: adapter_data: not supported

        input_ids = batch.input_ids
        position_ids = batch.position_ids
        cu_seqlen_prefill = batch.cu_seqlen_prefill
        kv_cache = self.kv_cache
        block_tables = batch.block_tables_tensor
        slots = batch.slots[batch.slot_indices]
        input_lengths = batch.input_lengths_tensor
        cache_lengths_tensor = batch.cache_lengths_tensor
        max_s = batch.max_current_length
        lm_head_indices = batch.prefill_head_indices

        if cu_seqlen_prefill is None and self.max_past() is not None:
            # In decode, not prefill, we're actually overwriting the KV-cache
            # in a circular buffer mode.
            # This makes sure the max_s for the decode pass is correct.
            max_s = min(self.max_past(), max_s)

        bs = input_ids.shape[0]
        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
        if sorted_padded_bs:
            # Get associated cuda graph
            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
        else:
            cuda_graph = None

        if cu_seqlen_prefill is not None or cuda_graph is None:
            if ATTENTION == "flashinfer":
                block_tables = block_tables_to_ragged(
                    block_tables=block_tables,
                    input_lengths=batch.input_lengths,
                    cache_lengths=batch.cache_lengths,
                    input_lengths_tensor=batch.input_lengths_tensor,
                    cache_lengths_tensor=batch.cache_lengths_tensor,
                    max_current_length=batch.max_current_length,
                )
            with self._forward_context(
                block_tables=block_tables,
                cu_seqlen_prefill=cu_seqlen_prefill,
                input_lengths_tensor=input_lengths,
                cache_lengths_tensor=cache_lengths_tensor,
            ):
                seqlen = Seqlen(
                    input_lengths=input_lengths,
                    cache_lengths=cache_lengths_tensor,
                    cu_seqlen_q=cu_seqlen_prefill,
                    max_q=batch.max_input_length,
                    max_k=batch.max_current_length,
                )
                logits = self.model.forward(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    past_key_values=None,
                    use_cache=False,  # we use self.kv_cache instead of transformers cache object
                    cu_seqlen_prefill=cu_seqlen_prefill,
                    kv_cache=kv_cache,
                    block_tables=block_tables,
                    slots=slots,
                    seqlen=seqlen,
                    max_s=max_s,
                    prefill_cache_indices=batch.prefill_cache_indices,
                    lm_head_indices=lm_head_indices,
                )
                if batch.prefill_cache_indices is not None:
                    batch.prefill_cache_indices = None
                return logits, None

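        # Decode-only batch that fits a captured CUDA graph: copy the batch into the
        # graph's padded static buffers and replay it instead of running eagerly.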
        # Copy inputs to the static inputs of the cuda graph
        # Static inputs are potentially padded
        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
        if ATTENTION == "flashinfer":
            block_tables = block_tables_to_ragged(
                block_tables=block_tables,
                input_lengths=batch.input_lengths,
                cache_lengths=batch.cache_lengths,
                input_lengths_tensor=batch.input_lengths_tensor,
                cache_lengths_tensor=batch.cache_lengths_tensor,
                max_current_length=batch.max_current_length,
            )
            # assert block_tables.shape[0] >= slots.shape[0]
            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
        else:
            cuda_graph["block_tables"][
                : block_tables.shape[0], : block_tables.shape[1]
            ] = block_tables

        # XXX: This is working only because block 0 is reserved for the healthcheck
        # so it doesn't matter if we override it with bogus values.
        cuda_graph["slots"].fill_(0)
        cuda_graph["slots"][: slots.shape[0]] = slots
        cuda_graph["input_lengths"].zero_()
        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
        cuda_graph["cache_lengths"].zero_()
        cuda_graph["cache_lengths"][
            : cache_lengths_tensor.shape[0]
        ] = cache_lengths_tensor

        with self._forward_context(
            block_tables=cuda_graph["block_tables"],
            cu_seqlen_prefill=None,
            input_lengths_tensor=cuda_graph["input_lengths"],
            cache_lengths_tensor=cuda_graph["cache_lengths"],
            state=cuda_graph["state"],
        ):
            # Replay the graph
            cuda_graph["graph"].replay()

        # Slice output to the correct shape
        logits = cuda_graph["logits"][:bs]
        return logits, None
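End to end, this class is normally constructed by get_model when the Transformers fallback is selected; a minimal sketch of direct construction (the model id is illustrative, and a CUDA or XPU device is assumed):

from text_generation_server.models.transformers_flash_causal_lm import (
    TransformersFlashCausalLM,
)

# Hypothetical direct use; inside TGI this is done by get_model in
# text_generation_server/models/__init__.py when USE_CUSTOM_MODELING is disabled
# and the Transformers class passes the flash/cache capability check.
model = TransformersFlashCausalLM(
    model_id="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model id
    revision=None,
    quantize=None,
    speculator=None,
    dtype=None,  # defaults to float16 on CUDA/XPU, float32 on CPU
    trust_remote_code=False,
)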