fix warmup issue for mllama

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Author: Wang, Yi A
Date:   2025-04-04 05:42:59 -07:00
parent 8591687561
commit 29703dbd27
3 changed files with 100 additions and 86 deletions


@@ -681,11 +681,10 @@ class MllamaTextCrossAttention(nn.Module):
         # bsz, q_len, _ = hidden_states.size()
         (
             cross_attention_states,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            cross_attention_len,
             indices,
         ) = cross_attention_states
-        bs = cu_seqlen_q.size(0) - 1
+        bs = cross_attention_len.size(0)
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
         query_states = self.q_norm(query_states)
@@ -814,8 +813,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         indices = cross_attention_states[-1]
         out_hidden_states = hidden_states[:]
-        if len(indices) > 0:
-            assert max(indices) < hidden_states.shape[0]
-            hidden_states = hidden_states[indices]
+        hidden_states = hidden_states[indices]

         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -914,59 +911,14 @@ class FlashMllamaForConditionalGeneration(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor],
         adapter_data: Optional[torch.Tensor] = None,
-        # XXX: Putting these as optional so that the cuda warmup calls can go through.
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        indices=None,
+        cross_attention_len: Optional[torch.Tensor] = None,
     ):
         if cross_attention_states is not None:
-            seqlen_q = len(image_indices)
-            n_images = cross_attention_states.shape[0]
-            seqlen_k = cross_attention_states.shape[1]
-            device = cross_attention_states.device
-            if cu_seqlen_prefill is not None:
-                offset = 0
-                cu_q = []
-                indices = []
-                for index in image_indices:
-                    cu_q.append(offset)
-                    length = seqlen.input_lengths[index].item()
-                    assert index < seqlen.cu_seqlen_q.shape[0]
-                    input_ids_offset = seqlen.cu_seqlen_q[index]
-                    indices.extend(range(input_ids_offset, input_ids_offset + length))
-                    offset += length
-                cu_q.append(offset)
-                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
-                assert max(indices) < input_ids.shape[0]
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-            else:
-                cu_seqlen_q = torch.arange(
-                    seqlen_q + 1, device=device, dtype=torch.int32
-                )
-                seqlen_k = cross_attention_states.shape[1]
-                n_images = cross_attention_states.shape[0]
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-                indices = image_indices[:]
             cross_attention_states = (
                 cross_attention_states,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                cross_attention_len,
                 indices,
             )
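
For orientation (outside the diff), here is a minimal sketch with made-up toy shapes of how the repacked cross-attention tuple flows after this change: the top-level forward now packs per-image lengths instead of the two cu_seqlen tensors, and MllamaTextCrossAttention derives the batch size directly from cross_attention_len.

import torch

# Toy shapes (assumed): 2 images, 4 vision tokens each, hidden size 8.
n_images, seq_k, hidden = 2, 4, 8
cross_attention_states = torch.zeros(n_images, seq_k, hidden)

# New tuple contents: per-image count of text tokens attending to that image,
# plus the flattened indices of those tokens in the packed hidden_states.
cross_attention_len = torch.tensor([3, 5], dtype=torch.int32)
indices = torch.arange(3 + 5)

# Packed in FlashMllamaForConditionalGeneration.forward ...
packed = (cross_attention_states, cross_attention_len, indices)

# ... and unpacked in MllamaTextCrossAttention.forward.
states, lengths, idx = packed
bs = lengths.size(0)  # previously cu_seqlen_q.size(0) - 1
assert bs == n_images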


@@ -1538,19 +1538,21 @@ class FlashCausalLM(Model):
                 self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)

-    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
+    def warmup_prefill(
+        self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
+    ):
         input_ids = torch.zeros(
             prompt_len, dtype=batch.input_ids.dtype, device=self.device
-        ).repeat(bs)
+        ).repeat(batch_size)
         position_ids = torch.arange(
             prompt_len, dtype=batch.position_ids.dtype, device=self.device
-        ).repeat(bs)
-        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
+        ).repeat(batch_size)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
-        ).reshape(bs, -1)
+        ).reshape(batch_size, -1)
         slot_acc = []
-        for i in range(bs):
+        for i in range(batch_size):
             slots = []
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
@@ -1558,10 +1560,14 @@ class FlashCausalLM(Model):
         slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
         input_lengths = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
         )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        cache_lengths_tensor = torch.zeros(
+            batch_size, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
         seqlen = Seqlen(
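
As an aside (not from the commit), the warmup metadata built above can be reproduced in isolation; a small sketch with assumed bucket sizes shows how cu_seqlen_prefill is derived from the per-sequence lengths via the in-place cumsum:

import torch

batch_size, prompt_len = 3, 4  # assumed warmup bucket sizes
device = torch.device("cpu")   # the real backend runs this on the HPU device

input_lengths = torch.ones(batch_size, dtype=torch.int32, device=device) * prompt_len
cache_lengths_tensor = torch.zeros(batch_size, dtype=torch.int32, device=device)

# cu_seqlen_prefill[i] holds the number of tokens before sequence i; the cumsum
# writes into a view of the tensor, so element 0 stays 0.
cu_seqlen_prefill = torch.zeros(batch_size + 1, device=device, dtype=torch.int32)
torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
print(cu_seqlen_prefill)  # tensor([ 0,  4,  8, 12], dtype=torch.int32)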


@@ -205,6 +205,24 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
         return batch


+def generate_cross_attention_states(
+    cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
+):
+    if cross_attention_states is None:
+        return None, None
+    device = cross_attention_states.device
+    indices_list = []
+    if prefilling:
+        for i in image_indices:
+            indices_list.append(
+                torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
+            )
+        indices = torch.cat(indices_list, dim=0)
+    else:
+        indices = image_indices[:]
+    return indices, seqlen.input_lengths.index_select(0, image_indices)
+
+
 class FlashMllamaCausalLM(FlashVlmCausalLM):
     def warmup_decode(
         self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
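
To make the new helper's behaviour concrete, here is a hedged, standalone rendition (not the module itself; the Seqlen metadata is stood in by a SimpleNamespace and the loop is written over plain ints): during prefill it expands each image index into the padded token range of that request, while during decode the image indices already are token indices and pass through.

import torch
from types import SimpleNamespace

def generate_cross_attention_states_sketch(
    cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
):
    if cross_attention_states is None:
        return None, None
    device = cross_attention_states.device
    if prefilling:
        # Prefill: request i occupies tokens [pad_seq_len * i, pad_seq_len * (i + 1)).
        indices = torch.cat(
            [
                torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
                for i in image_indices.tolist()
            ]
        )
    else:
        # Decode: one token per request, so image indices double as token indices.
        indices = image_indices[:]
    return indices, seqlen.input_lengths.index_select(0, image_indices)

# Assumed toy inputs: two requests, one image each, padded to 4 tokens at prefill.
states = torch.zeros(2, 6, 8)  # (n_images, vision_seq_len, hidden)
image_indices = torch.tensor([0, 1])
seqlen = SimpleNamespace(input_lengths=torch.tensor([3, 4]))

idx, lengths = generate_cross_attention_states_sketch(
    states, image_indices, seqlen, pad_seq_len=4, prefilling=True
)
# idx     -> tensor([0, 1, 2, 3, 4, 5, 6, 7])   tokens of request 0, then request 1
# lengths -> tensor([3, 4])                     true (unpadded) lengths per image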
@@ -255,6 +273,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             bucketing_ctx=None,
         )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = image_indices.repeat(batch_size)
+        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states, image_indices, seqlen, 1, False
+        )
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -262,26 +286,29 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             kv_cache=self.kv_cache,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
             adapter_data=None,
-            hpu_attention_meta=hpu_attention_meta,
-            cross_attention_states=batch.cross_attention_states,
-            image_indices=batch.image_indices[:],
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
         )

-    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
+    def warmup_prefill(
+        self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
+    ):
         input_ids = torch.zeros(
             prompt_len, dtype=batch.input_ids.dtype, device=self.device
-        ).repeat(bs)
+        ).repeat(batch_size)
         position_ids = torch.arange(
             prompt_len, dtype=batch.position_ids.dtype, device=self.device
-        ).repeat(bs)
-        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
+        ).repeat(batch_size)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
-        ).reshape(bs, -1)
+        ).reshape(batch_size, -1)
         slot_acc = []
-        for i in range(bs):
+        for i in range(batch_size):
             slots = []
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
@@ -289,10 +316,14 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
         input_lengths = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
         )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        cache_lengths_tensor = torch.zeros(
+            batch_size, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
         seqlen = Seqlen(
@@ -303,6 +334,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         lm_head_indices = input_lengths - 1
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = image_indices.repeat(batch_size)
+        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states, image_indices, seqlen, prompt_len, True
+        )
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -310,11 +347,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             kv_cache=self.kv_cache,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
-            lm_head_indices=lm_head_indices,
-            cross_attention_states=batch.cross_attention_states,
-            adapter_data=None,
             hpu_attention_meta=None,
-            image_indices=batch.image_indices[:],
+            lm_head_indices=lm_head_indices,
+            adapter_data=None,
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
         )

     def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
@@ -433,6 +471,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         else:
             padded_bs = input_lengths.shape[0]
         orig_bs = input_lengths.shape[0]
+        padded_input_len = input_ids.view(orig_bs, -1).shape[-1]
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
         if padded_bs != input_lengths.shape[0]:
             padded_input_lengths = F.pad(
                 input_lengths,
@@ -454,26 +494,41 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                 cache_lengths=padded_cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
             )
-            input_seq = input_ids.view(orig_bs, -1)
             input_ids = F.pad(
-                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+                input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0
             )
             position_ids = F.pad(
-                position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
+                position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1
             )
-            slots = F.pad(
-                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
-            )
+            slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
             if lm_head_indices is not None:
                 lm_head_indices = F.pad(
                     lm_head_indices, (0, padded_bs - orig_bs), value=0
                 )
+            if cross_attention_states is not None:
+                cross_attention_states = F.pad(
+                    cross_attention_states,
+                    (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
+                    value=0,
+                )
+            if len(image_indices) != 0:
+                pad_indices = torch.arange(orig_bs, padded_bs, device=self.device)
+                image_indices = torch.cat((image_indices, pad_indices), dim=0)
         else:
             seqlen = Seqlen(
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
             )
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states,
+            image_indices,
+            seqlen,
+            padded_input_len,
+            batch.prefilling,
+        )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -483,10 +538,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
             lm_head_indices=lm_head_indices,
-            cross_attention_states=cross_attention_states,
             # TODO list
             adapter_data=None,
-            image_indices=batch.image_indices[:],
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
             **kwargs,
         )
         if batch.pixel_values is not None:
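
Finally, as an illustration of the new padding branch (again outside the diff; sizes are assumed): when the HPU bucket is larger than the real batch, the vision states are zero-padded along the batch dimension and each padded slot gets its own dummy image index, so the cross-attention gather stays shape-consistent.

import torch
import torch.nn.functional as F

orig_bs, padded_bs = 2, 4                            # assumed real vs. bucketed batch size
cross_attention_states = torch.randn(orig_bs, 6, 8)  # (n_images, vision_seq_len, hidden)
image_indices = torch.tensor([0, 1])

# F.pad pads trailing dims first, so the last (0, padded_bs - orig_bs) pair pads dim 0.
cross_attention_states = F.pad(
    cross_attention_states, (0, 0, 0, 0, 0, padded_bs - orig_bs), value=0
)

# One dummy image index per padded request.
pad_indices = torch.arange(orig_bs, padded_bs)
image_indices = torch.cat((image_indices, pad_indices), dim=0)

print(cross_attention_states.shape)  # torch.Size([4, 6, 8])
print(image_indices)                 # tensor([0, 1, 2, 3])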