mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 12:24:53 +00:00
fix high dim
This commit is contained in:
parent
f843b62a44
commit
715b2d19ed
@ -44,7 +44,7 @@ def tgi_flash_attention_forward(
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]]
|
||||
kv_cache = kwargs["kv_cache"][module.layer_idx]
|
||||
# This means no scale
|
||||
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
|
||||
|
||||
@ -97,7 +97,6 @@ def tgi_flash_attention_forward(
|
||||
softcap=softcap,
|
||||
)
|
||||
|
||||
# attn_output = attn_output.view(attn_output.shape[0], -1)
|
||||
attn_output = attn_output.view(-1, num_heads * head_dim)
|
||||
|
||||
return attn_output, None
|
||||
@ -244,6 +243,42 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
def _model_forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
cu_seqlen_prefill,
|
||||
kv_cache,
|
||||
block_tables,
|
||||
slots,
|
||||
seqlen,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
lm_head_indices,
|
||||
):
|
||||
hidden_states = self.model.model.forward(
|
||||
input_ids=input_ids[None, ...], # expand dim to easily fit transformers
|
||||
position_ids=position_ids[None, ...], # expand dim to easily fit transformers
|
||||
past_key_values=None, # we use self.kv_cache instead of transformers cache object
|
||||
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
||||
return_dict=True,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=prefill_cache_indices,
|
||||
kv_head_mapping=self.kv_head_mapping,
|
||||
)[0].squeeze(dim=0)
|
||||
# And compute logits from the lm_head, slicing correctly the indices
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits = self.model.lm_head.forward(hidden_states)
|
||||
return logits
|
||||
|
||||
|
||||
def forward(
|
||||
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
@ -297,13 +332,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
max_q=batch.max_input_length,
|
||||
max_k=batch.max_current_length,
|
||||
)
|
||||
# Use only the Model, not ModelForCausalLM
|
||||
hidden_states = self.model.model.forward(
|
||||
input_ids=input_ids[None, ...], # expand dim to easily fit transformers
|
||||
position_ids=position_ids[None, ...],
|
||||
past_key_values=None,
|
||||
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
||||
return_dict=True,
|
||||
logits = self._model_forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
@ -311,10 +342,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
kv_head_mapping=self.kv_head_mapping,
|
||||
)[0].squeeze(dim=0)
|
||||
# And compute logits from the lm_head, slicing correctly the indices
|
||||
logits = self.model.lm_head.forward(hidden_states[lm_head_indices])
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits, None
|
||||
@ -363,3 +392,180 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
||||
# Slice output to the correct shape
|
||||
logits = cuda_graph["logits"][:bs]
|
||||
return logits, None
|
||||
|
||||
|
||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
|
||||
input_lengths = [max_s] * bs
|
||||
cache_lengths = [0] * bs
|
||||
if max_bs is None:
|
||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||
input_lengths_tensor = (
|
||||
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||
)
|
||||
cache_lengths_tensor = torch.zeros(
|
||||
bs, dtype=torch.int32, device=self.device
|
||||
)
|
||||
block_tables = torch.arange(
|
||||
max_bt, dtype=torch.int32, device=self.device
|
||||
).repeat(bs)
|
||||
block_tables = block_tables.reshape((bs, max_bt))
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = block_tables_to_ragged(
|
||||
block_tables=block_tables,
|
||||
input_lengths=input_lengths,
|
||||
cache_lengths=cache_lengths,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
max_current_length=max_s,
|
||||
)
|
||||
else:
|
||||
if bs > max_bs:
|
||||
raise RuntimeError(
|
||||
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
|
||||
)
|
||||
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
|
||||
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
|
||||
if ATTENTION == "flashinfer":
|
||||
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
|
||||
else:
|
||||
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
|
||||
slots = self.cuda_graphs[max_bs]["slots"][:bs]
|
||||
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
|
||||
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
|
||||
|
||||
if ATTENTION == "flashinfer":
|
||||
from text_generation_server.layers.attention.flashinfer import (
|
||||
create_decode_state_cuda_graphs,
|
||||
)
|
||||
|
||||
block_tables_ptr = torch.zeros(
|
||||
bs + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
|
||||
state = create_decode_state_cuda_graphs(
|
||||
device=input_ids.device,
|
||||
block_tables=block_tables,
|
||||
block_tables_ptr=block_tables_ptr,
|
||||
last_page_len=last_page_len,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
)
|
||||
else:
|
||||
state = None
|
||||
|
||||
if (
|
||||
hasattr(self.model, "config")
|
||||
and hasattr(self.model.config, "model_type")
|
||||
and self.model.config.model_type == "qwen2_vl"
|
||||
):
|
||||
if position_ids.dim() == 1:
|
||||
position_ids = self.model.get_position_ids(input_ids)
|
||||
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs[bs] = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"kv_cache": self.kv_cache,
|
||||
"block_tables": block_tables,
|
||||
"slots": slots,
|
||||
"input_lengths": input_lengths_tensor,
|
||||
"cache_lengths": cache_lengths_tensor,
|
||||
"state": state,
|
||||
"graph": graph,
|
||||
}
|
||||
|
||||
torch.cuda.synchronize()
|
||||
# Run once outside to warmup
|
||||
with self._forward_context(
|
||||
block_tables=block_tables,
|
||||
cu_seqlen_prefill=None,
|
||||
input_lengths_tensor=input_lengths_tensor,
|
||||
state=state,
|
||||
cache_lengths_tensor=cache_lengths_tensor,
|
||||
):
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths_tensor,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=None,
|
||||
max_q=1,
|
||||
max_k=max_s,
|
||||
)
|
||||
self._model_forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
del seqlen
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths_tensor,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=None,
|
||||
max_q=1,
|
||||
max_k=max_s,
|
||||
)
|
||||
logits = self._model_forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=None,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
seqlen=seqlen,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
lm_head_indices=None,
|
||||
)
|
||||
self.cuda_graphs[bs]["logits"] = logits
|
||||
self.cuda_graphs[bs]["speculative_logits"] = None
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def tunableop_warmup(self, seqlen: int):
|
||||
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
|
||||
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||
|
||||
# Dummy value, some models (starcoder2) don't accept `None`.
|
||||
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
|
||||
cache_lengths_tensor = torch.zeros(
|
||||
seqlen, dtype=torch.int32, device=self.device
|
||||
)
|
||||
cu_seqlen_prefill = torch.tensor(
|
||||
[0, seqlen], device=self.device, dtype=torch.int32
|
||||
)
|
||||
max_s = seqlen
|
||||
seqlen = Seqlen(
|
||||
input_lengths=input_lengths,
|
||||
cache_lengths=cache_lengths_tensor,
|
||||
cu_seqlen_q=cu_seqlen_prefill,
|
||||
max_q=1,
|
||||
max_k=seqlen,
|
||||
)
|
||||
|
||||
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||
self._model_forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=self.kv_cache,
|
||||
block_tables=None,
|
||||
seqlen=seqlen,
|
||||
slots=slots,
|
||||
max_s=max_s,
|
||||
lm_head_indices=None,
|
||||
prefill_cache_indices=None,
|
||||
)
|
Loading…
Reference in New Issue
Block a user