fix high dim

This commit is contained in:
Cyril Vallez 2024-12-13 13:25:56 +00:00
parent f843b62a44
commit 715b2d19ed

View File

@ -44,7 +44,7 @@ def tgi_flash_attention_forward(
**kwargs,
):
kv_cache = kwargs["kv_cache"][kwargs["layer_idx"]]
kv_cache = kwargs["kv_cache"][module.layer_idx]
# This means no scale
kv_scales=KVScales(torch.tensor(1., device=key_states.device), torch.tensor(1., device=key_states.device))
@ -97,7 +97,6 @@ def tgi_flash_attention_forward(
softcap=softcap,
)
# attn_output = attn_output.view(attn_output.shape[0], -1)
attn_output = attn_output.view(-1, num_heads * head_dim)
return attn_output, None
@ -244,6 +243,42 @@ class TransformersFlashCausalLM(FlashCausalLM):
trust_remote_code=trust_remote_code,
)
def _model_forward(
self,
input_ids,
position_ids,
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
seqlen,
max_s,
prefill_cache_indices,
lm_head_indices,
):
hidden_states = self.model.model.forward(
input_ids=input_ids[None, ...], # expand dim to easily fit transformers
position_ids=position_ids[None, ...], # expand dim to easily fit transformers
past_key_values=None, # we use self.kv_cache instead of transformers cache object
use_cache=False, # we use self.kv_cache instead of transformers cache object
return_dict=True,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
kv_head_mapping=self.kv_head_mapping,
)[0].squeeze(dim=0)
# And compute logits from the lm_head, slicing correctly the indices
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]
logits = self.model.lm_head.forward(hidden_states)
return logits
def forward(
self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@ -297,13 +332,9 @@ class TransformersFlashCausalLM(FlashCausalLM):
max_q=batch.max_input_length,
max_k=batch.max_current_length,
)
# Use only the Model, not ModelForCausalLM
hidden_states = self.model.model.forward(
input_ids=input_ids[None, ...], # expand dim to easily fit transformers
position_ids=position_ids[None, ...],
past_key_values=None,
use_cache=False, # we use self.kv_cache instead of transformers cache object
return_dict=True,
logits = self._model_forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=kv_cache,
block_tables=block_tables,
@ -311,10 +342,8 @@ class TransformersFlashCausalLM(FlashCausalLM):
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
kv_head_mapping=self.kv_head_mapping,
)[0].squeeze(dim=0)
# And compute logits from the lm_head, slicing correctly the indices
logits = self.model.lm_head.forward(hidden_states[lm_head_indices])
lm_head_indices=lm_head_indices,
)
if batch.prefill_cache_indices is not None:
batch.prefill_cache_indices = None
return logits, None
@ -363,3 +392,180 @@ class TransformersFlashCausalLM(FlashCausalLM):
# Slice output to the correct shape
logits = cuda_graph["logits"][:bs]
return logits, None
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
input_lengths = [max_s] * bs
cache_lengths = [0] * bs
if max_bs is None:
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
slots = torch.arange(bs, dtype=torch.int64, device=self.device)
input_lengths_tensor = (
torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
)
cache_lengths_tensor = torch.zeros(
bs, dtype=torch.int32, device=self.device
)
block_tables = torch.arange(
max_bt, dtype=torch.int32, device=self.device
).repeat(bs)
block_tables = block_tables.reshape((bs, max_bt))
if ATTENTION == "flashinfer":
block_tables = block_tables_to_ragged(
block_tables=block_tables,
input_lengths=input_lengths,
cache_lengths=cache_lengths,
input_lengths_tensor=input_lengths_tensor,
cache_lengths_tensor=cache_lengths_tensor,
max_current_length=max_s,
)
else:
if bs > max_bs:
raise RuntimeError(
"Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
)
input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
if ATTENTION == "flashinfer":
block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
else:
block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
slots = self.cuda_graphs[max_bs]["slots"][:bs]
input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
if ATTENTION == "flashinfer":
from text_generation_server.layers.attention.flashinfer import (
create_decode_state_cuda_graphs,
)
block_tables_ptr = torch.zeros(
bs + 1, dtype=torch.int32, device=self.device
)
last_page_len = torch.ones(bs, dtype=torch.int32, device=self.device)
state = create_decode_state_cuda_graphs(
device=input_ids.device,
block_tables=block_tables,
block_tables_ptr=block_tables_ptr,
last_page_len=last_page_len,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
)
else:
state = None
if (
hasattr(self.model, "config")
and hasattr(self.model.config, "model_type")
and self.model.config.model_type == "qwen2_vl"
):
if position_ids.dim() == 1:
position_ids = self.model.get_position_ids(input_ids)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs] = {
"input_ids": input_ids,
"position_ids": position_ids,
"kv_cache": self.kv_cache,
"block_tables": block_tables,
"slots": slots,
"input_lengths": input_lengths_tensor,
"cache_lengths": cache_lengths_tensor,
"state": state,
"graph": graph,
}
torch.cuda.synchronize()
# Run once outside to warmup
with self._forward_context(
block_tables=block_tables,
cu_seqlen_prefill=None,
input_lengths_tensor=input_lengths_tensor,
state=state,
cache_lengths_tensor=cache_lengths_tensor,
):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
self._model_forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
del seqlen
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
seqlen = Seqlen(
input_lengths=input_lengths_tensor,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=None,
max_q=1,
max_k=max_s,
)
logits = self._model_forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=None,
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = None
torch.cuda.synchronize()
def tunableop_warmup(self, seqlen: int):
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
cache_lengths_tensor = torch.zeros(
seqlen, dtype=torch.int32, device=self.device
)
cu_seqlen_prefill = torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
)
max_s = seqlen
seqlen = Seqlen(
input_lengths=input_lengths,
cache_lengths=cache_lengths_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=1,
max_k=seqlen,
)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self._model_forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=cu_seqlen_prefill,
kv_cache=self.kv_cache,
block_tables=None,
seqlen=seqlen,
slots=slots,
max_s=max_s,
lm_head_indices=None,
prefill_cache_indices=None,
)