mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-09-11 20:34:54 +00:00
fix high dim
This commit is contained in:
parent
f843b62a44
commit
715b2d19ed
@ -44,7 +44,7 @@ def tgi_flash_attention_forward(
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|
||||||
kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]]
|
kv_cache = kwargs["kv_cache"][module.layer_idx]
|
||||||
# This means no scale
|
# This means no scale
|
||||||
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
|
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
|
||||||
|
|
||||||
@ -97,7 +97,6 @@ def tgi_flash_attention_forward(
|
|||||||
softcap=softcap,
|
softcap=softcap,
|
||||||
)
|
)
|
||||||
|
|
||||||
# attn_output = attn_output.view(attn_output.shape[0], -1)
|
|
||||||
attn_output = attn_output.view(-1, num_heads * head_dim)
|
attn_output = attn_output.view(-1, num_heads * head_dim)
|
||||||
|
|
||||||
return attn_output, None
|
return attn_output, None
|
||||||
@ -244,6 +243,42 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _model_forward(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
cu_seqlen_prefill,
|
||||||
|
kv_cache,
|
||||||
|
block_tables,
|
||||||
|
slots,
|
||||||
|
seqlen,
|
||||||
|
max_s,
|
||||||
|
prefill_cache_indices,
|
||||||
|
lm_head_indices,
|
||||||
|
):
|
||||||
|
hidden_states = self.model.model.forward(
|
||||||
|
input_ids=input_ids[None, ...], # expand dim to easily fit transformers
|
||||||
|
position_ids=position_ids[None, ...], # expand dim to easily fit transformers
|
||||||
|
past_key_values=None, # we use self.kv_cache instead of transformers cache object
|
||||||
|
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
||||||
|
return_dict=True,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
seqlen=seqlen,
|
||||||
|
max_s=max_s,
|
||||||
|
prefill_cache_indices=prefill_cache_indices,
|
||||||
|
kv_head_mapping=self.kv_head_mapping,
|
||||||
|
)[0].squeeze(dim=0)
|
||||||
|
# And compute logits from the lm_head, slicing correctly the indices
|
||||||
|
if lm_head_indices is not None:
|
||||||
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
logits = self.model.lm_head.forward(hidden_states)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
@ -297,13 +332,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
max_q=batch.max_input_length,
|
max_q=batch.max_input_length,
|
||||||
max_k=batch.max_current_length,
|
max_k=batch.max_current_length,
|
||||||
)
|
)
|
||||||
# Use only the Model, not ModelForCausalLM
|
logits = self._model_forward(
|
||||||
hidden_states = self.model.model.forward(
|
input_ids=input_ids,
|
||||||
input_ids=input_ids[None, ...], # expand dim to easily fit transformers
|
position_ids=position_ids,
|
||||||
position_ids=position_ids[None, ...],
|
|
||||||
past_key_values=None,
|
|
||||||
use_cache=False, # we use self.kv_cache instead of transformers cache object
|
|
||||||
return_dict=True,
|
|
||||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
@ -311,10 +342,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
seqlen=seqlen,
|
seqlen=seqlen,
|
||||||
max_s=max_s,
|
max_s=max_s,
|
||||||
prefill_cache_indices=batch.prefill_cache_indices,
|
prefill_cache_indices=batch.prefill_cache_indices,
|
||||||
kv_head_mapping=self.kv_head_mapping,
|
lm_head_indices=lm_head_indices,
|
||||||
)[0].squeeze(dim=0)
|
)
|
||||||
# And compute logits from the lm_head, slicing correctly the indices
|
|
||||||
logits = self.model.lm_head.forward(hidden_states[lm_head_indices])
|
|
||||||
if batch.prefill_cache_indices is not None:
|
if batch.prefill_cache_indices is not None:
|
||||||
batch.prefill_cache_indices = None
|
batch.prefill_cache_indices = None
|
||||||
return logits, None
|
return logits, None
|
||||||
@ -363,3 +392,180 @@ class TransformersFlashCausalLM(FlashCausalLM):
|
|||||||
# Slice output to the correct shape
|
# Slice output to the correct shape
|
||||||
logits = cuda_graph["logits"][:bs]
|
logits = cuda_graph["logits"][:bs]
|
||||||
return logits, None
|
return logits, None
|
||||||
|
|
||||||
|
|
||||||
|
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||||
|
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
|
||||||
|
input_lengths = [max_s] * bs
|
||||||
|
cache_lengths = [0] * bs
|
||||||
|
if max_bs is None:
|
||||||
|
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||||
|
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
|
||||||
|
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
|
||||||
|
input_lengths_tensor = (
|
||||||
|
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
|
||||||
|
)
|
||||||
|
cache_lengths_tensor = torch.zeros(
|
||||||
|
bs, dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
block_tables = torch.arange(
|
||||||
|
max_bt, dtype=torch.int32, device=self.device
|
||||||
|
).repeat(bs)
|
||||||
|
block_tables = block_tables.reshape((bs, max_bt))
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
|
block_tables = block_tables_to_ragged(
|
||||||
|
block_tables=block_tables,
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
cache_lengths=cache_lengths,
|
||||||
|
input_lengths_tensor=input_lengths_tensor,
|
||||||
|
cache_lengths_tensor=cache_lengths_tensor,
|
||||||
|
max_current_length=max_s,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if bs > max_bs:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
|
||||||
|
)
|
||||||
|
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
|
||||||
|
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
|
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
|
||||||
|
else:
|
||||||
|
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
|
||||||
|
slots = self.cuda_graphs[max_bs]["slots"][:bs]
|
||||||
|
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
|
||||||
|
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
|
||||||
|
|
||||||
|
if ATTENTION == "flashinfer":
|
||||||
|
from text_generation_server.layers.attention.flashinfer import (
|
||||||
|
create_decode_state_cuda_graphs,
|
||||||
|
)
|
||||||
|
|
||||||
|
block_tables_ptr = torch.zeros(
|
||||||
|
bs + 1, dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
|
||||||
|
state = create_decode_state_cuda_graphs(
|
||||||
|
device=input_ids.device,
|
||||||
|
block_tables=block_tables,
|
||||||
|
block_tables_ptr=block_tables_ptr,
|
||||||
|
last_page_len=last_page_len,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
state = None
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(self.model, "config")
|
||||||
|
and hasattr(self.model.config, "model_type")
|
||||||
|
and self.model.config.model_type == "qwen2_vl"
|
||||||
|
):
|
||||||
|
if position_ids.dim() == 1:
|
||||||
|
position_ids = self.model.get_position_ids(input_ids)
|
||||||
|
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
self.cuda_graphs[bs] = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
"kv_cache": self.kv_cache,
|
||||||
|
"block_tables": block_tables,
|
||||||
|
"slots": slots,
|
||||||
|
"input_lengths": input_lengths_tensor,
|
||||||
|
"cache_lengths": cache_lengths_tensor,
|
||||||
|
"state": state,
|
||||||
|
"graph": graph,
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
# Run once outside to warmup
|
||||||
|
with self._forward_context(
|
||||||
|
block_tables=block_tables,
|
||||||
|
cu_seqlen_prefill=None,
|
||||||
|
input_lengths_tensor=input_lengths_tensor,
|
||||||
|
state=state,
|
||||||
|
cache_lengths_tensor=cache_lengths_tensor,
|
||||||
|
):
|
||||||
|
seqlen = Seqlen(
|
||||||
|
input_lengths=input_lengths_tensor,
|
||||||
|
cache_lengths=cache_lengths_tensor,
|
||||||
|
cu_seqlen_q=None,
|
||||||
|
max_q=1,
|
||||||
|
max_k=max_s,
|
||||||
|
)
|
||||||
|
self._model_forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=None,
|
||||||
|
kv_cache=self.kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
seqlen=seqlen,
|
||||||
|
max_s=max_s,
|
||||||
|
prefill_cache_indices=None,
|
||||||
|
lm_head_indices=None,
|
||||||
|
)
|
||||||
|
del seqlen
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
with torch.cuda.graph(graph, pool=MEM_POOL):
|
||||||
|
seqlen = Seqlen(
|
||||||
|
input_lengths=input_lengths_tensor,
|
||||||
|
cache_lengths=cache_lengths_tensor,
|
||||||
|
cu_seqlen_q=None,
|
||||||
|
max_q=1,
|
||||||
|
max_k=max_s,
|
||||||
|
)
|
||||||
|
logits = self._model_forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=None,
|
||||||
|
kv_cache=self.kv_cache,
|
||||||
|
block_tables=block_tables,
|
||||||
|
slots=slots,
|
||||||
|
seqlen=seqlen,
|
||||||
|
max_s=max_s,
|
||||||
|
prefill_cache_indices=None,
|
||||||
|
lm_head_indices=None,
|
||||||
|
)
|
||||||
|
self.cuda_graphs[bs]["logits"] = logits
|
||||||
|
self.cuda_graphs[bs]["speculative_logits"] = None
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
def tunableop_warmup(self, seqlen: int):
|
||||||
|
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
|
||||||
|
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
|
||||||
|
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
|
# Dummy value, some models (starcoder2) don't accept `None`.
|
||||||
|
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
|
||||||
|
cache_lengths_tensor = torch.zeros(
|
||||||
|
seqlen, dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
cu_seqlen_prefill = torch.tensor(
|
||||||
|
[0, seqlen], device=self.device, dtype=torch.int32
|
||||||
|
)
|
||||||
|
max_s = seqlen
|
||||||
|
seqlen = Seqlen(
|
||||||
|
input_lengths=input_lengths,
|
||||||
|
cache_lengths=cache_lengths_tensor,
|
||||||
|
cu_seqlen_q=cu_seqlen_prefill,
|
||||||
|
max_q=1,
|
||||||
|
max_k=seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
|
||||||
|
self._model_forward(
|
||||||
|
input_ids=input_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||||
|
kv_cache=self.kv_cache,
|
||||||
|
block_tables=None,
|
||||||
|
seqlen=seqlen,
|
||||||
|
slots=slots,
|
||||||
|
max_s=max_s,
|
||||||
|
lm_head_indices=None,
|
||||||
|
prefill_cache_indices=None,
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user