Mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-09-11 12:24:53 +00:00)

Commit: ade0f44aca ("add transformers_flash")
Parent: 9f5c9a5e22
@@ -20,6 +20,7 @@ from pathlib import Path
 from text_generation_server.utils.speculate import get_speculate, set_speculate
 from text_generation_server.models.model import Model
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast
+from text_generation_server.models.transformers_flash_causal_lm import TransformersFlashCausalLM
 from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM
 from text_generation_server.models.custom_modeling.mpt_modeling import (
     MPTForCausalLM,
@@ -28,7 +29,7 @@ from text_generation_server.models.bloom import BloomCausalLMBatch
 from text_generation_server.models.custom_modeling.bloom_modeling import (
     BloomForCausalLM,
 )
-from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.globals import ATTENTION, USE_CUSTOM_MODELING
 from text_generation_server.models.seq2seq_lm import Seq2SeqLM
 from text_generation_server.models.galactica import GalacticaCausalLMBatch
 from text_generation_server.models.custom_modeling.neox_modeling import (
@@ -366,12 +367,38 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
+    global USE_CUSTOM_MODELING
 
     config_dict, _ = PretrainedConfig.get_config_dict(
         model_id, revision=revision, trust_remote_code=trust_remote_code
     )
     model_type = config_dict.get("model_type", None)
 
+    transformers_causal_lm_class = CausalLM
+    if (
+        not USE_CUSTOM_MODELING
+        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    ):
+        logger.info(
+            "TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
+        )
+        transformers_model_class = getattr(
+            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
+        )
+
+        if (
+            transformers_model_class._supports_flash_attn_2
+            and transformers_model_class._supports_cache_class
+        ):
+            logger.info(
+                f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
+            )
+            transformers_causal_lm_class = TransformersFlashCausalLM
+        else:
+            logger.info(
+                f"Transformers' {model_type} implementation does not support custom cache and flash/paged attention. Using TransformersCausalLM with classic tensors with padding (two dimensions for batch size and sequence length)."
+            )
+
     quantization_config = config_dict.get("quantization_config", None)
     if quantization_config is None:
         quantization_config = config_dict.get("compression_config", None)
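The selection logic above can be exercised outside of TGI. The sketch below is a minimal, hedged illustration (not part of this commit) of the same capability check: look the architecture up in Transformers' causal-LM auto-mapping, then inspect the `_supports_flash_attn_2` and `_supports_cache_class` class flags that decide between the ragged TransformersFlashCausalLM path and the padded CausalLM fallback. "llama" is only an example model_type, and the flags are read defensively since they are private Transformers attributes.

# Minimal sketch of the capability check performed in get_model() above.
import transformers
from transformers.models.auto import modeling_auto

model_type = "llama"  # example value of config_dict.get("model_type")
class_name = modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
model_class = getattr(transformers, class_name)

supports_ragged = getattr(model_class, "_supports_flash_attn_2", False) and getattr(
    model_class, "_supports_cache_class", False
)
print(class_name, "-> TransformersFlashCausalLM" if supports_ragged else "-> CausalLM fallback")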
@@ -67,3 +67,7 @@ def set_adapter_to_index(adapter_to_index: Dict[str, int]):
 def get_adapter_to_index():
     global ADAPTER_TO_INDEX
     return ADAPTER_TO_INDEX
+
+
+USE_CUSTOM_MODELING = os.getenv("USE_CUSTOM_MODELING", "true")
+USE_CUSTOM_MODELING = USE_CUSTOM_MODELING == "true" or USE_CUSTOM_MODELING == "1"
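For reference, the new flag is a plain environment switch: only the strings "true" or "1" (case-sensitive) enable custom modeling, anything else routes requests through the Transformers fallback shown earlier. A small hedged sketch of the same parsing follows; `parse_bool_env` is an illustrative name, not a TGI helper.

import os

def parse_bool_env(name: str, default: str = "true") -> bool:
    # Mirrors the two lines added above: only "true" or "1" turn the flag on.
    value = os.getenv(name, default)
    return value == "true" or value == "1"

# e.g. USE_CUSTOM_MODELING=0 selects the Transformers code path in get_model().
USE_CUSTOM_MODELING = parse_bool_env("USE_CUSTOM_MODELING")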
@@ -0,0 +1,309 @@
+import math
+import sys
+from typing import Optional, Tuple, Dict, Any
+
+import torch
+from opentelemetry import trace
+from loguru import logger
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+from text_generation_server.models.flash_causal_lm import (
+    FlashCausalLMBatch,
+    FlashCausalLM,
+)
+from text_generation_server.utils.import_utils import (
+    empty_cache,
+    synchronize,
+    get_free_memory,
+)
+from text_generation_server.adapters import AdapterBatchData
+from text_generation_server.layers.attention import paged_attention, attention, Seqlen
+from text_generation_server.layers.attention.kv_cache import KVScales
+from text_generation_server.models.globals import ATTENTION
+from text_generation_server.models.metadata_kernels import block_tables_to_ragged
+
+
+tracer = trace.get_tracer(__name__)
+
+
+def patch_everywhere(
+    attribute_name: str, patch: Any, module_name_prefix: Optional[str] = None
+):
+    """
+    Finds all occurrences of `attribute_name` in the loaded modules and patches them with `patch`.
+
+    Args:
+        attribute_name (`str`):
+            The name of the attribute to patch.
+        patch (`Any`):
+            The patch for the attribute.
+        module_name_prefix (`Optional[str]`, defaults to `None`):
+            If set, only module names starting with this prefix will be considered for patching.
+    """
+    # sys.modules may be updated while being iterated over, hence the list copy.
+    for name in list(sys.modules):
+        module = sys.modules[name]
+        if module_name_prefix is not None and not name.startswith(module_name_prefix):
+            continue
+        if hasattr(module, attribute_name):
+            setattr(module, attribute_name, patch)
+
+
+def _flash_attention_forward_patched(
+    query_states: torch.Tensor,
+    key_states: torch.Tensor,
+    value_states: torch.Tensor,
+    attention_mask: torch.Tensor,
+    query_length: int,
+    is_causal: bool,
+    softmax_scale: Optional[float] = None,
+    sliding_window: int = -1,
+    softcap: Optional[float] = None,
+    **kwargs,
+):
+    kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]]
+    # This means no scale
+    kv_scales = KVScales(torch.tensor(1.0, device=key_states.device), torch.tensor(1.0, device=key_states.device))
+
+    # Correctly reshape the states
+    _, _, num_heads, head_dim = query_states.size()
+    _, _, num_kv_heads, _ = key_states.size()
+    query_states = query_states.view(-1, num_heads, head_dim)
+    key_states = key_states.view(-1, num_kv_heads, head_dim)
+    value_states = value_states.view(-1, num_kv_heads, head_dim)
+
+    # Take care of updating the cache in-place
+    kv_cache.store(
+        key=key_states,
+        value=value_states,
+        slots=kwargs["slots"],
+        kv_scales=kv_scales,
+    )
+
+    softmax_scale = 1 / math.sqrt(head_dim) if softmax_scale is None else softmax_scale
+
+    if kwargs["cu_seqlen_prefill"] is not None:
+        attn_output = attention(
+            query=query_states,
+            key=key_states,
+            value=value_states,
+            kv_cache=kv_cache,
+            kv_scales=kv_scales,
+            seqlen=kwargs["seqlen"],
+            block_tables=kwargs["block_tables"],
+            softmax_scale=softmax_scale,
+            window_size_left=sliding_window,
+            softcap=softcap,
+        )
+    else:
+        attn_output = paged_attention(
+            query_states,
+            kv_cache,
+            kwargs["kv_head_mapping"],
+            softmax_scale,
+            kwargs["block_tables"],
+            kwargs["seqlen"],
+            kwargs["max_s"],
+            kv_scales=kv_scales,
+            softcap=softcap,
+        )
+
+    attn_output = attn_output.view(attn_output.shape[0], -1)
+
+    return attn_output
+
+
+class TransformersFlashCausalLM(FlashCausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        speculator: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        if speculator:
+            raise RuntimeError("Speculator decoding is not enabled for AutoModel")
+
+        device_count = 0
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            device_count = torch.cuda.device_count()
+            dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            device_count = torch.xpu.device_count()
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map=("auto" if device_count > 1 else None),
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+        )
+        if device_count == 1 and quantize != "bitsandbytes":
+            model = model.to(device)
+
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None and isinstance(
+                model.config.eos_token_id, int
+            ):
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.num_layers = len(model.model.layers)
+        self.num_kv_heads = model.config.num_key_value_heads
+        self.head_size = model.config.hidden_size // model.config.num_attention_heads
+
+        # Skip FlashCausalLM init.
+        super(FlashCausalLM, self).__init__(
+            model_id=model_id,
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=False,
+            dtype=dtype,
+            device=device,
+        )
+
+    def warmup(self, batch: FlashCausalLMBatch):
+        patch_everywhere("_flash_attention_forward", _flash_attention_forward_patched)
+        super().warmup(batch)
+
+    def forward(
+        self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # NOTE: adapter_data: not supported
+
+        input_ids = batch.input_ids
+        position_ids = batch.position_ids
+        cu_seqlen_prefill = batch.cu_seqlen_prefill
+        kv_cache = self.kv_cache
+        block_tables = batch.block_tables_tensor
+        slots = batch.slots[batch.slot_indices]
+        input_lengths = batch.input_lengths_tensor
+        cache_lengths_tensor = batch.cache_lengths_tensor
+        max_s = batch.max_current_length
+        lm_head_indices = batch.prefill_head_indices
+
+        if cu_seqlen_prefill is None and self.max_past() is not None:
+            # In decode, not prefill, we're actually overwriting the KV-cache
+            # in a circular buffer mode.
+            # This makes sure the max_s for the decode pass is correct.
+            max_s = min(self.max_past(), max_s)
+
+        bs = input_ids.shape[0]
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=batch.input_lengths,
+                    cache_lengths=batch.cache_lengths,
+                    input_lengths_tensor=batch.input_lengths_tensor,
+                    cache_lengths_tensor=batch.cache_lengths_tensor,
+                    max_current_length=batch.max_current_length,
+                )
+            with self._forward_context(
+                block_tables=block_tables,
+                cu_seqlen_prefill=cu_seqlen_prefill,
+                input_lengths_tensor=input_lengths,
+                cache_lengths_tensor=cache_lengths_tensor,
+            ):
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    cache_lengths=cache_lengths_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=batch.max_input_length,
+                    max_k=batch.max_current_length,
+                )
+                logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    past_key_values=None,
+                    use_cache=False,  # we use self.kv_cache instead of transformers cache object
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    seqlen=seqlen,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
+                if batch.prefill_cache_indices is not None:
+                    batch.prefill_cache_indices = None
+                return logits, None
+
+        # Copy inputs to the static inputs of the cuda graph
+        # Static inputs are potentially padded
+        cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
+        cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
+        if ATTENTION == "flashinfer":
+            block_tables = block_tables_to_ragged(
+                block_tables=block_tables,
+                input_lengths=batch.input_lengths,
+                cache_lengths=batch.cache_lengths,
+                input_lengths_tensor=batch.input_lengths_tensor,
+                cache_lengths_tensor=batch.cache_lengths_tensor,
+                max_current_length=batch.max_current_length,
+            )
+            # assert block_tables.shape[0] >= slots.shape[0]
+            cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
+        else:
+            cuda_graph["block_tables"][
+                : block_tables.shape[0], : block_tables.shape[1]
+            ] = block_tables
+
+        # XXX: This is working only because block 0 is reserved for the healthcheck
+        # so it doesn't matter if we override it with bogus values.
+        cuda_graph["slots"].fill_(0)
+        cuda_graph["slots"][: slots.shape[0]] = slots
+        cuda_graph["input_lengths"].zero_()
+        cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
+        cuda_graph["cache_lengths"].zero_()
+        cuda_graph["cache_lengths"][
+            : cache_lengths_tensor.shape[0]
+        ] = cache_lengths_tensor
+
+        with self._forward_context(
+            block_tables=cuda_graph["block_tables"],
+            cu_seqlen_prefill=None,
+            input_lengths_tensor=cuda_graph["input_lengths"],
+            cache_lengths_tensor=cuda_graph["cache_lengths"],
+            state=cuda_graph["state"],
+        ):
+            # Replay the graph
+            cuda_graph["graph"].replay()
+
+        # Slice output to the correct shape
+        logits = cuda_graph["logits"][:bs]
+        return logits, None
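The key mechanism of the new file is `patch_everywhere`: at warmup it replaces Transformers' `_flash_attention_forward` in every loaded module with `_flash_attention_forward_patched`, so the model's attention calls store into and read from TGI's paged KV cache instead of the usual Transformers cache. The sketch below is a self-contained, hedged illustration of that patching behavior only; the demo module and lambdas are stand-ins, not TGI or Transformers code.

import sys
import types

def patch_everywhere(attribute_name, patch, module_name_prefix=None):
    # Same traversal as in the file above: copy sys.modules first, then swap the
    # attribute in every (optionally prefix-filtered) module that defines it.
    for name in list(sys.modules):
        module = sys.modules[name]
        if module_name_prefix is not None and not name.startswith(module_name_prefix):
            continue
        if hasattr(module, attribute_name):
            setattr(module, attribute_name, patch)

# Illustrative stand-in for a modeling module that exposes _flash_attention_forward.
demo = types.ModuleType("demo_modeling")
demo._flash_attention_forward = lambda *args, **kwargs: "original attention"
sys.modules["demo_modeling"] = demo

patch_everywhere("_flash_attention_forward", lambda *args, **kwargs: "patched attention", "demo_")
print(sys.modules["demo_modeling"]._flash_attention_forward())  # -> "patched attention"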