mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
Simplify with monkey patch
This commit is contained in:
parent
2659b5998b
commit
a2fe842795
@ -1,32 +1,17 @@
|
||||
import math
|
||||
import sys
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from opentelemetry import trace
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
|
||||
import transformers.modeling_utils
|
||||
|
||||
from text_generation_server.models.flash_causal_lm import (
|
||||
FlashCausalLMBatch,
|
||||
FlashCausalLM,
|
||||
)
|
||||
from text_generation_server.utils import (
|
||||
initialize_torch_distributed,
|
||||
weight_files,
|
||||
Weights,
|
||||
)
|
||||
from text_generation_server.utils.import_utils import (
|
||||
empty_cache,
|
||||
synchronize,
|
||||
get_free_memory,
|
||||
)
|
||||
from text_generation_server.adapters import AdapterBatchData
|
||||
from text_generation_server.models.flash_causal_lm import FlashCausalLM
|
||||
from text_generation_server.utils import initialize_torch_distributed
|
||||
|
||||
from text_generation_server.layers.attention import paged_attention, attention, Seqlen
|
||||
from text_generation_server.layers.attention.kv_cache import KVScales, KVCache
|
||||
from text_generation_server.models.globals import ATTENTION
|
||||
from text_generation_server.models.metadata_kernels import block_tables_to_ragged
|
||||
|
||||
|
||||
tracer = trace.get_tracer(__name__)
|
||||
@ -48,7 +33,7 @@ def tgi_flash_attention_forward(
|
||||
softmax_scale: Optional[float] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
softcap: Optional[float] = None,
|
||||
**_kwargs, # This is needed to "absorb" other args passed by Transformers modeling
|
||||
**kwargs, # This is needed to "absorb" other args passed by Transformers modeling
|
||||
):
|
||||
|
||||
kv_cache = kv_cache[module.layer_idx]
|
||||
@ -222,6 +207,11 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
# Monkey patch of `self.model.forward` to match `FlashCausalLM`. It avoids duplicating a lot of code
|
||||
# We first copy the original model.forward because we still need it in the monkey patch
|
||||
self.model.original_forward = self.model.forward
|
||||
self.model.forward = self._model_forward
|
||||
|
||||
@classmethod
|
||||
def fallback(
|
||||
cls,
|
||||
@ -252,12 +242,15 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
seqlen: Seqlen,
|
||||
max_s: int,
|
||||
lm_head_indices: Optional[torch.Tensor],
|
||||
prefill_cache_indices = None, # not used, but passed to match original signature
|
||||
adapter_data = None, # not supported, but passed to match original signature
|
||||
):
|
||||
# Transformers does not support None as a default
|
||||
if lm_head_indices is None:
|
||||
lm_head_indices = 0
|
||||
|
||||
logits = self.model.forward(
|
||||
# Equivalent tp `self.model.forward`, see the monkey patch in __init__
|
||||
logits = self.model.original_forward(
|
||||
input_ids=input_ids.unsqueeze(0), # expand dim to easily fit transformers
|
||||
position_ids=position_ids.unsqueeze(0), # expand dim to easily fit transformers
|
||||
past_key_values=None, # we use self.kv_cache instead of transformers cache object
|
||||
@ -272,292 +265,5 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
max_s=max_s,
|
||||
kv_head_mapping=self.kv_head_mapping,
|
||||
).logits.squeeze(dim=0)
|
||||
return logits
|
||||
|
||||
|
||||
def forward(
|
||||
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
# NOTE: adapter_data: not supported
|
||||
|
||||
input_ids = batch.input_ids
|
||||
position_ids = batch.position_ids
|
||||
cu_seqlen_prefill = batch.cu_seqlen_prefill
|
||||
kv_cache = self.kv_cache
|
||||
block_tables = batch.block_tables_tensor
|
||||
slots = batch.slots[batch.slot_indices]
|
||||
input_lengths = batch.input_lengths_tensor
|
||||
cache_lengths_tensor = batch.cache_lengths_tensor
|
||||
max_s = batch.max_current_length
|
||||
lm_head_indices = batch.prefill_head_indices
|
||||
|
||||
if cu_seqlen_prefill is None and self.max_past() is not None:
|
||||
# In decode, not prefill, we're actually overwriting the KV-cache
|
||||
# in a circular buffer mode.
|
||||
# This makes sure the max_s for the decode pass is correct.
|
||||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
if sorted_padded_bs:
|
||||
# Get associated cuda graph
|
||||
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
|
||||
else:
|
||||
cuda_graph = None
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
input_lengths_tensor=batch.input_lengths_tensor,
|
||||
cache_lengths_tensor=batch.cache_lengths_tensor,
|
||||
max_current_length=batch.max_current_length,
|
||||
)
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
input_lengths_tensor=input_lengths,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
):
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=cu_seqlen_prefill,
|
||||
max_q=batch.max_input_length,
|
||||
max_k=batch.max_current_length,
|
||||
)
|
||||
logits = self._model_forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits, None
|
||||
|
||||
# Copy inputs to the static inputs of the cuda graph
|
||||
# Static inputs are potentially padded
|
||||
cuda_graph["input_ids"][: input_ids.shape[0]] = input_ids
|
||||
cuda_graph["position_ids"][: position_ids.shape[-1]] = position_ids
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=batch.input_lengths,
|
||||
cache_lengths=batch.cache_lengths,
|
||||
input_lengths_tensor=batch.input_lengths_tensor,
|
||||
cache_lengths_tensor=batch.cache_lengths_tensor,
|
||||
max_current_length=batch.max_current_length,
|
||||
)
|
||||
# assert block_tables.shape[0] >= slots.shape[0]
|
||||
cuda_graph["block_tables"][: block_tables.shape[0]] = block_tables
|
||||
else:
|
||||
cuda_graph["block_tables"][
|
||||
: block_tables.shape[0], : block_tables.shape[1]
|
||||
] = block_tables
|
||||
|
||||
# XXX: This is working only because block 0 is reserved for the healthcheck
|
||||
# so it doesn't matter if we override it with bogus values.
|
||||
cuda_graph["slots"].fill_(0)
|
||||
cuda_graph["slots"][: slots.shape[0]] = slots
|
||||
cuda_graph["input_lengths"].zero_()
|
||||
cuda_graph["input_lengths"][: input_lengths.shape[0]] = input_lengths
|
||||
cuda_graph["cache_lengths"].zero_()
|
||||
cuda_graph["cache_lengths"][
|
||||
: cache_lengths_tensor.shape[0]
|
||||
] = cache_lengths_tensor
|
||||
|
||||
with self._forward_context(
|
||||
block_tables=cuda_graph["block_tables"],
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths_tensor=cuda_graph["input_lengths"],
|
||||
cache_lengths_tensor=cuda_graph["cache_lengths"],
|
||||
state=cuda_graph["state"],
|
||||
):
|
||||
# Replay the graph
|
||||
cuda_graph["graph"].replay()
|
||||
|
||||
# Slice output to the correct shape
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, None
|
||||
|
||||
|
||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
|
||||
input_lengths = [max_s] * bs
|
||||
cache_lengths = [0] * bs
|
||||
if max_bs is None:
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||
input_lengths_tensor = (
|
||||
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
)
|
||||
cache_lengths_tensor = torch.zeros(
|
||||
bs, dtype=torch.int32, device=self.device
|
||||
)
|
||||
block_tables = torch.arange(
|
||||
max_bt, dtype=torch.int32, device=self.device
|
||||
).repeat(bs)
|
||||
block_tables = block_tables.reshape((bs, max_bt))
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=input_lengths,
|
||||
cache_lengths=cache_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
max_current_length=max_s,
|
||||
)
|
||||
else:
|
||||
if bs > max_bs:
|
||||
raise RuntimeError(
|
||||
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
|
||||
)
|
||||
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
|
||||
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
|
||||
else:
|
||||
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
|
||||
slots = self.cuda_graphs[max_bs]["slots"][:bs]
|
||||
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
|
||||
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flashinfer import (
|
||||
create_decode_state_cuda_graphs,
|
||||
)
|
||||
|
||||
block_tables_ptr = torch.zeros(
|
||||
bs + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
|
||||
state = create_decode_state_cuda_graphs(
|
||||
device=input_ids.device,
|
||||
block_tables=block_tables,
|
||||
block_tables_ptr=block_tables_ptr,
|
||||
last_page_len=last_page_len,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
)
|
||||
else:
|
||||
state = None
|
||||
|
||||
if (
|
||||
hasattr(self.model, "config")
|
||||
and hasattr(self.model.config, "model_type")
|
||||
and self.model.config.model_type == "qwen2_vl"
|
||||
):
|
||||
if position_ids.dim() == 1:
|
||||
position_ids = self.model.get_position_ids(input_ids)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs] = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"kv_cache": self.kv_cache,
|
||||
"block_tables": block_tables,
|
||||
"slots": slots,
|
||||
"input_lengths": input_lengths_tensor,
|
||||
"cache_lengths": cache_lengths_tensor,
|
||||
"state": state,
|
||||
"graph": graph,
|
||||
}
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# Run once outside to warmup
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
state=state,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
):
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths_tensor,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=None,
|
||||
max_q=1,
|
||||
max_k=max_s,
|
||||
)
|
||||
self._model_forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
del seqlen
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths_tensor,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=None,
|
||||
max_q=1,
|
||||
max_k=max_s,
|
||||
)
|
||||
logits = self._model_forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
self.cuda_graphs[bs]["logits"] = logits
|
||||
self.cuda_graphs[bs]["speculative_logits"] = None
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def tunableop_warmup(self, seqlen: int):
|
||||
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||
|
||||
# Dummy value, some models (starcoder2) don't accept `None`.
|
||||
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
|
||||
cache_lengths_tensor = torch.zeros(
|
||||
seqlen, dtype=torch.int32, device=self.device
|
||||
)
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
[0, seqlen], device=self.device, dtype=torch.int32
|
||||
)
|
||||
max_s = seqlen
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=cu_seqlen_prefill,
|
||||
max_q=1,
|
||||
max_k=seqlen,
|
||||
)
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self._model_forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=None,
|
||||
seqlen=seqlen,
|
||||
slots=slots,
|
||||
max_s=max_s,
|
||||
lm_head_indices=None,
|
||||
)
|
Loading…
Reference in New Issue
Block a user