Mirror of https://github.com/huggingface/text-generation-inference.git, synced 2025-04-24 00:12:08 +00:00
fix warmup issue for mllama
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in: parent 8591687561, commit 29703dbd27
@@ -681,11 +681,10 @@ class MllamaTextCrossAttention(nn.Module):
        # bsz, q_len, _ = hidden_states.size()
        (
            cross_attention_states,
            cu_seqlen_q,
            cu_seqlen_k,
            cross_attention_len,
            indices,
        ) = cross_attention_states
        bs = cu_seqlen_q.size(0) - 1
        bs = cross_attention_len.size(0)
        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
        query_states = self.q_norm(query_states)
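For context on the changed `bs` computation in this hunk: a cumulative-length tensor such as `cu_seqlen_q` has shape `(bs + 1,)` (a leading zero plus one running total per sequence), while a per-sequence length tensor such as `cross_attention_len` has exactly `bs` entries. A minimal standalone sketch with made-up values, not part of the commit:

    import torch

    cu_seqlen_q = torch.tensor([0, 5, 9, 12], dtype=torch.int32)       # (bs + 1,) running totals
    cross_attention_len = torch.tensor([5, 4, 3], dtype=torch.int32)   # (bs,) one length per sequence

    # Both expressions recover the same batch size of 3.
    assert cu_seqlen_q.size(0) - 1 == cross_attention_len.size(0) == 3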
@@ -814,8 +813,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):

        indices = cross_attention_states[-1]
        out_hidden_states = hidden_states[:]
        if len(indices) > 0:
            assert max(indices) < hidden_states.shape[0]
            hidden_states = hidden_states[indices]
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
@@ -914,59 +911,14 @@ class FlashMllamaForConditionalGeneration(nn.Module):
        hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
        lm_head_indices: Optional[torch.Tensor],
        adapter_data: Optional[torch.Tensor] = None,
        # XXX: Putting these as optional so that the cuda warmup calls can go through.
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
        indices=None,
        cross_attention_len: Optional[torch.Tensor] = None,
    ):
        if cross_attention_states is not None:
            seqlen_q = len(image_indices)
            n_images = cross_attention_states.shape[0]
            seqlen_k = cross_attention_states.shape[1]
            device = cross_attention_states.device
            if cu_seqlen_prefill is not None:
                offset = 0
                cu_q = []
                indices = []
                for index in image_indices:
                    cu_q.append(offset)
                    length = seqlen.input_lengths[index].item()
                    assert index < seqlen.cu_seqlen_q.shape[0]
                    input_ids_offset = seqlen.cu_seqlen_q[index]
                    indices.extend(range(input_ids_offset, input_ids_offset + length))
                    offset += length
                cu_q.append(offset)
                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)

                assert max(indices) < input_ids.shape[0]

                cu_seqlen_k = (
                    torch.arange(
                        n_images + 1,
                        device=device,
                        dtype=torch.int32,
                    )
                    * seqlen_k
                )
            else:
                cu_seqlen_q = torch.arange(
                    seqlen_q + 1, device=device, dtype=torch.int32
                )
                seqlen_k = cross_attention_states.shape[1]
                n_images = cross_attention_states.shape[0]
                cu_seqlen_k = (
                    torch.arange(
                        n_images + 1,
                        device=device,
                        dtype=torch.int32,
                    )
                    * seqlen_k
                )
                indices = image_indices[:]

            cross_attention_states = (
                cross_attention_states,
                cu_seqlen_q,
                cu_seqlen_k,
                cross_attention_len,
                indices,
            )
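The `cu_seqlen_k` expression in the hunk above builds cumulative key offsets by assuming every image contributes a fixed number of cross-attention tokens (`seqlen_k`). A minimal sketch with made-up sizes, illustration only and not taken from the commit:

    import torch

    n_images, seqlen_k = 3, 4  # assumed sizes for the example
    cu_seqlen_k = torch.arange(n_images + 1, dtype=torch.int32) * seqlen_k
    # tensor([ 0,  4,  8, 12], dtype=torch.int32): image i owns keys [4*i, 4*(i+1))
    print(cu_seqlen_k)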
@@ -1538,19 +1538,21 @@ class FlashCausalLM(Model):
                self.warmup_decode(batch_size, block_num, batch)
        synchronize(self.device)

    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
    def warmup_prefill(
        self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
    ):
        input_ids = torch.zeros(
            prompt_len, dtype=batch.input_ids.dtype, device=self.device
        ).repeat(bs)
        ).repeat(batch_size)
        position_ids = torch.arange(
            prompt_len, dtype=batch.position_ids.dtype, device=self.device
        ).repeat(bs)
        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
        ).repeat(batch_size)
        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
        block_tables = torch.arange(
            max_bt, dtype=torch.int32, device=self.device
        ).reshape(bs, -1)
        ).reshape(batch_size, -1)
        slot_acc = []
        for i in range(bs):
        for i in range(batch_size):
            slots = []
            for b in block_tables[i]:
                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
@@ -1558,10 +1560,14 @@ class FlashCausalLM(Model):
        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)

        input_lengths = (
            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
        )
        cache_lengths_tensor = torch.zeros(
            batch_size, dtype=torch.int32, device=self.device
        )
        cu_seqlen_prefill = torch.zeros(
            batch_size + 1, device=self.device, dtype=torch.int32
        )
        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

        seqlen = Seqlen(
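The warmup helper above enumerates KV-cache slots from the dummy block tables: block `b` owns the contiguous slot range `[b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE)`. A small self-contained sketch with assumed values (`BLOCK_SIZE = 4`, two sequences of two blocks each); the aggregation of `slots` into `slot_acc` is assumed, since that line falls outside the hunk:

    import torch

    BLOCK_SIZE = 4                      # assumed for the example
    batch_size, blocks_per_seq = 2, 2
    block_tables = torch.arange(batch_size * blocks_per_seq, dtype=torch.int32).reshape(
        batch_size, -1
    )                                   # [[0, 1], [2, 3]]

    slot_acc = []
    for i in range(batch_size):
        slots = []
        for b in block_tables[i].tolist():
            slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
        slot_acc.extend(slots)          # assumed aggregation step
    print(slot_acc)                     # 16 slots: [0, 1, 2, ..., 15]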
@@ -205,6 +205,24 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
        return batch


def generate_cross_attention_states(
    cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
):
    if cross_attention_states is None:
        return None, None, None
    device = cross_attention_states.device
    indices_list = []
    if prefilling:
        for i in image_indices:
            indices_list.append(
                torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
            )
        indices = torch.cat(indices_list, dim=0)
    else:
        indices = image_indices[:]
    return indices, seqlen.input_lengths.index_select(0, image_indices)


class FlashMllamaCausalLM(FlashVlmCausalLM):
    def warmup_decode(
        self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
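A hedged usage sketch for the `generate_cross_attention_states` helper added above: during prefill, image `i` (attached to padded sequence `i`) maps to token positions `pad_seq_len * i` through `pad_seq_len * (i + 1) - 1`; during decode the image indices are returned as-is. The values below are made up, and the sketch only reproduces the prefill index arithmetic rather than calling the helper itself:

    import torch

    pad_seq_len = 3
    image_indices = torch.tensor([0, 2])          # images attached to sequences 0 and 2
    prefill_indices = torch.cat(
        [
            torch.arange(pad_seq_len * i, pad_seq_len * (i + 1))
            for i in image_indices.tolist()
        ],
        dim=0,
    )
    print(prefill_indices)                        # tensor([0, 1, 2, 6, 7, 8])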
@@ -255,6 +273,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
            bucketing_ctx=None,
        )
        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        image_indices = torch.tensor(batch.image_indices, device=self.device)
        image_indices = image_indices.repeat(batch_size)
        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
        indices, cross_attention_len = generate_cross_attention_states(
            cross_attention_states, image_indices, seqlen, 1, False
        )
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
@@ -262,26 +286,29 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
            kv_cache=self.kv_cache,
            slots=slots,
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=hpu_attention_meta,
            lm_head_indices=None,
            adapter_data=None,
            hpu_attention_meta=hpu_attention_meta,
            cross_attention_states=batch.cross_attention_states,
            image_indices=batch.image_indices[:],
            cross_attention_states=cross_attention_states,
            indices=indices,
            cross_attention_len=cross_attention_len,
        )

    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
    def warmup_prefill(
        self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
    ):
        input_ids = torch.zeros(
            prompt_len, dtype=batch.input_ids.dtype, device=self.device
        ).repeat(bs)
        ).repeat(batch_size)
        position_ids = torch.arange(
            prompt_len, dtype=batch.position_ids.dtype, device=self.device
        ).repeat(bs)
        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
        ).repeat(batch_size)
        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
        block_tables = torch.arange(
            max_bt, dtype=torch.int32, device=self.device
        ).reshape(bs, -1)
        ).reshape(batch_size, -1)
        slot_acc = []
        for i in range(bs):
        for i in range(batch_size):
            slots = []
            for b in block_tables[i]:
                slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
@@ -289,10 +316,14 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
        slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)

        input_lengths = (
            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
        )
        cache_lengths_tensor = torch.zeros(
            batch_size, dtype=torch.int32, device=self.device
        )
        cu_seqlen_prefill = torch.zeros(
            batch_size + 1, device=self.device, dtype=torch.int32
        )
        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
        torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

        seqlen = Seqlen(
@@ -303,6 +334,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
        lm_head_indices = input_lengths - 1

        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
        image_indices = torch.tensor(batch.image_indices, device=self.device)
        image_indices = image_indices.repeat(batch_size)
        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
        indices, cross_attention_len = generate_cross_attention_states(
            cross_attention_states, image_indices, seqlen, prompt_len, True
        )
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
@@ -310,11 +347,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
            kv_cache=self.kv_cache,
            slots=slots,
            seqlen=trim_seqlen_metadata(seqlen),
            lm_head_indices=lm_head_indices,
            cross_attention_states=batch.cross_attention_states,
            adapter_data=None,
            hpu_attention_meta=None,
            image_indices=batch.image_indices[:],
            lm_head_indices=lm_head_indices,
            adapter_data=None,
            cross_attention_states=cross_attention_states,
            indices=indices,
            cross_attention_len=cross_attention_len,
        )

    def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
@@ -433,6 +471,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
        else:
            padded_bs = input_lengths.shape[0]
        orig_bs = input_lengths.shape[0]
        padded_input_len = input_ids.view(orig_bs, -1).shape[-1]
        image_indices = torch.tensor(batch.image_indices, device=self.device)
        if padded_bs != input_lengths.shape[0]:
            padded_input_lengths = F.pad(
                input_lengths,
@@ -454,26 +494,41 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                cache_lengths=padded_cache_lengths_tensor,
                cu_seqlen_q=cu_seqlen_prefill,
            )
            input_seq = input_ids.view(orig_bs, -1)

            input_ids = F.pad(
                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
                input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0
            )
            position_ids = F.pad(
                position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
            )
            slots = F.pad(
                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
                position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1
            )
            slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
            if lm_head_indices is not None:
                lm_head_indices = F.pad(
                    lm_head_indices, (0, padded_bs - orig_bs), value=0
                )
            if cross_attention_states is not None:
                cross_attention_states = F.pad(
                    cross_attention_states,
                    (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
                    value=0,
                )
            if len(image_indices) != 0:
                pad_indices = torch.arange(orig_bs, padded_bs, device=self.device)
                image_indices = torch.cat((image_indices, pad_indices), dim=0)
        else:
            seqlen = Seqlen(
                input_lengths=input_lengths,
                cache_lengths=cache_lengths_tensor,
                cu_seqlen_q=cu_seqlen_prefill,
            )

        indices, cross_attention_len = generate_cross_attention_states(
            cross_attention_states,
            image_indices,
            seqlen,
            padded_input_len,
            batch.prefilling,
        )
        logits, speculative_logits = self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
@@ -483,10 +538,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
            seqlen=trim_seqlen_metadata(seqlen),
            hpu_attention_meta=batch.hpu_attn_meta,
            lm_head_indices=lm_head_indices,
            cross_attention_states=cross_attention_states,
            # TODO list
            adapter_data=None,
            image_indices=batch.image_indices[:],
            cross_attention_states=cross_attention_states,
            indices=indices,
            cross_attention_len=cross_attention_len,
            **kwargs,
        )
        if batch.pixel_values is not None:
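As a reference for the padding logic in the last two hunks: with a 1-D tensor, `F.pad` takes a `(left, right)` pair and grows the last dimension, which is how the flat `input_ids`, `position_ids`, and `slots` tensors are extended from `orig_bs` to `padded_bs` sequences of `padded_input_len` tokens, while the padded sequences each receive their own image index. A minimal sketch with made-up sizes, illustration only:

    import torch
    import torch.nn.functional as F

    orig_bs, padded_bs, padded_input_len = 2, 4, 3          # assumed sizes
    input_ids = torch.arange(orig_bs * padded_input_len)    # 6 real tokens
    input_ids = F.pad(input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
    print(input_ids.shape)                                  # torch.Size([12])

    # Padded sequences get their own image index so every entry stays valid.
    image_indices = torch.tensor([0, 1])
    pad_indices = torch.arange(orig_bs, padded_bs)
    print(torch.cat((image_indices, pad_indices), dim=0))   # tensor([0, 1, 2, 3])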