mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-04-24 08:22:07 +00:00

commit 29703dbd27 (parent 8591687561)

    fix warmup issue for mllama

    Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@@ -681,11 +681,10 @@ class MllamaTextCrossAttention(nn.Module):
         # bsz, q_len, _ = hidden_states.size()
         (
             cross_attention_states,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            cross_attention_len,
             indices,
         ) = cross_attention_states
-        bs = cu_seqlen_q.size(0) - 1
+        bs = cross_attention_len.size(0)
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
         query_states = self.q_norm(query_states)
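
The tuple unpacked by MllamaTextCrossAttention now carries a per-image length tensor instead of the cumulative Q/K offsets, and the batch size is read directly off its shape. A minimal sketch of the equivalence, assuming a fixed number of vision tokens per image (the values below are illustrative, not from the diff):

    import torch

    # Three requests, six vision tokens each (assumed numbers).
    cross_attention_len = torch.tensor([6, 6, 6], dtype=torch.int32)
    bs_new = cross_attention_len.size(0)                           # 3
    cu_seqlen_q = torch.tensor([0, 6, 12, 18], dtype=torch.int32)  # old cumulative form
    bs_old = cu_seqlen_q.size(0) - 1                               # also 3
    assert bs_new == bs_old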
@@ -814,8 +813,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):

         indices = cross_attention_states[-1]
         out_hidden_states = hidden_states[:]
-        if len(indices) > 0:
-            assert max(indices) < hidden_states.shape[0]
         hidden_states = hidden_states[indices]
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
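
Dropping the `len(indices) > 0` guard and the `max(indices)` assert removes data-dependent host-side control flow from the layer, which would otherwise force a device-to-host sync on every step and get in the way of graph warmup. Plain tensor indexing already handles the empty case uniformly; a small sketch with dummy shapes:

    import torch

    hidden_states = torch.randn(8, 16)
    # Non-empty and empty index tensors take the same indexing path,
    # so no Python-level branch is needed.
    some = hidden_states[torch.tensor([2, 5], dtype=torch.long)]  # shape (2, 16)
    none = hidden_states[torch.empty(0, dtype=torch.long)]        # shape (0, 16)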
@@ -914,59 +911,14 @@ class FlashMllamaForConditionalGeneration(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor],
         adapter_data: Optional[torch.Tensor] = None,
-        # XXX: Putting these as optional so that the cuda warmup calls can go through.
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        indices=None,
+        cross_attention_len: Optional[torch.Tensor] = None,
     ):
         if cross_attention_states is not None:
-            seqlen_q = len(image_indices)
-            n_images = cross_attention_states.shape[0]
-            seqlen_k = cross_attention_states.shape[1]
-            device = cross_attention_states.device
-            if cu_seqlen_prefill is not None:
-                offset = 0
-                cu_q = []
-                indices = []
-                for index in image_indices:
-                    cu_q.append(offset)
-                    length = seqlen.input_lengths[index].item()
-                    assert index < seqlen.cu_seqlen_q.shape[0]
-                    input_ids_offset = seqlen.cu_seqlen_q[index]
-                    indices.extend(range(input_ids_offset, input_ids_offset + length))
-                    offset += length
-                cu_q.append(offset)
-                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
-
-                assert max(indices) < input_ids.shape[0]
-
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-            else:
-                cu_seqlen_q = torch.arange(
-                    seqlen_q + 1, device=device, dtype=torch.int32
-                )
-                seqlen_k = cross_attention_states.shape[1]
-                n_images = cross_attention_states.shape[0]
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-                indices = image_indices[:]
-
             cross_attention_states = (
                 cross_attention_states,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                cross_attention_len,
                 indices,
             )

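With this change, `forward` no longer rebuilds `cu_seqlen_q`/`cu_seqlen_k` from `seqlen` on every call; the caller precomputes the token indices and per-image lengths once (via the `generate_cross_attention_states` helper added later in this commit), and `forward` only repackages them. A sketch of the new calling convention with dummy tensors (all sizes assumed):

    import torch

    n_images, seqlen_k, hidden = 2, 6, 4
    states = torch.randn(n_images, seqlen_k, hidden)
    cross_attention_len = torch.full((n_images,), seqlen_k, dtype=torch.int32)
    indices = torch.arange(n_images * seqlen_k)  # token positions attending to images
    cross_attention_states = (states, cross_attention_len, indices)
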
@@ -1538,19 +1538,21 @@ class FlashCausalLM(Model):
         self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)

-    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
+    def warmup_prefill(
+        self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
+    ):
         input_ids = torch.zeros(
             prompt_len, dtype=batch.input_ids.dtype, device=self.device
-        ).repeat(bs)
+        ).repeat(batch_size)
         position_ids = torch.arange(
             prompt_len, dtype=batch.position_ids.dtype, device=self.device
-        ).repeat(bs)
-        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
+        ).repeat(batch_size)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
-        ).reshape(bs, -1)
+        ).reshape(batch_size, -1)
         slot_acc = []
-        for i in range(bs):
+        for i in range(batch_size):
             slots = []
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
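
Apart from the `bs` to `batch_size` rename, the warmup geometry is unchanged: each dummy sequence gets `prompt_len // BLOCK_SIZE + 1` cache blocks, and each block contributes `BLOCK_SIZE` contiguous slots. A worked example with assumed numbers:

    # Assumed values for illustration only.
    prompt_len, BLOCK_SIZE, batch_size = 256, 128, 2
    blocks_per_seq = prompt_len // BLOCK_SIZE + 1  # 3 blocks per sequence
    max_bt = blocks_per_seq * batch_size           # 6 block-table entries overall
    # Block b covers slots [b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE).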
@@ -1558,10 +1560,14 @@ class FlashCausalLM(Model):
         slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)

         input_lengths = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
         )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        cache_lengths_tensor = torch.zeros(
+            batch_size, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

         seqlen = Seqlen(
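`cu_seqlen_prefill` is the exclusive prefix sum of the (equal) warmup lengths, written in place into the pre-zeroed tensor; this hunk only reflows the `torch.zeros` calls, the computation is unchanged. A sketch of the pattern:

    import torch

    input_lengths = torch.tensor([5, 5, 5], dtype=torch.int32)
    cu_seqlen_prefill = torch.zeros(4, dtype=torch.int32)
    torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
    # cu_seqlen_prefill is now tensor([0, 5, 10, 15], dtype=torch.int32)
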
@@ -205,6 +205,24 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
         return batch


+def generate_cross_attention_states(
+    cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
+):
+    if cross_attention_states is None:
+        return None, None, None
+    device = cross_attention_states.device
+    indices_list = []
+    if prefilling:
+        for i in image_indices:
+            indices_list.append(
+                torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
+            )
+        indices = torch.cat(indices_list, dim=0)
+    else:
+        indices = image_indices[:]
+    return indices, seqlen.input_lengths.index_select(0, image_indices)
+
+
 class FlashMllamaCausalLM(FlashVlmCausalLM):
     def warmup_decode(
         self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
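
A usage sketch of the new helper, assuming it is in scope; TGI's `Seqlen` metadata is stubbed here with a plain namespace, since only `.input_lengths` is consulted:

    import torch
    from types import SimpleNamespace

    seqlen = SimpleNamespace(input_lengths=torch.tensor([128, 128], dtype=torch.int32))
    image_indices = torch.tensor([0, 1])
    states = torch.randn(2, 6, 4)  # dummy cross-attention states

    indices, cross_attention_len = generate_cross_attention_states(
        states, image_indices, seqlen, pad_seq_len=128, prefilling=True
    )
    # indices covers the padded token spans [0, 128) and [128, 256);
    # cross_attention_len selects the per-image lengths: tensor([128, 128]).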
@@ -255,6 +273,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             bucketing_ctx=None,
         )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = image_indices.repeat(batch_size)
+        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states, image_indices, seqlen, 1, False
+        )
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
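
In the decode warmup each request contributes one image and one new token, so the per-batch image indices and the dummy states are simply tiled `batch_size` times before calling the helper with `pad_seq_len=1` and `prefilling=False`. A sketch of the tiling, with assumed shapes:

    import torch

    batch_image_indices = [0]                                    # one image per request
    image_indices = torch.tensor(batch_image_indices).repeat(3)  # batch_size=3 -> tensor([0, 0, 0])
    states = torch.randn(1, 6, 4).repeat(3, 1, 1)                # (3, 6, 4)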
@@ -262,26 +286,29 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             kv_cache=self.kv_cache,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
             adapter_data=None,
-            hpu_attention_meta=hpu_attention_meta,
-            cross_attention_states=batch.cross_attention_states,
-            image_indices=batch.image_indices[:],
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
         )

-    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
+    def warmup_prefill(
+        self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
+    ):
         input_ids = torch.zeros(
             prompt_len, dtype=batch.input_ids.dtype, device=self.device
-        ).repeat(bs)
+        ).repeat(batch_size)
         position_ids = torch.arange(
             prompt_len, dtype=batch.position_ids.dtype, device=self.device
-        ).repeat(bs)
-        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
+        ).repeat(batch_size)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
-        ).reshape(bs, -1)
+        ).reshape(batch_size, -1)
         slot_acc = []
-        for i in range(bs):
+        for i in range(batch_size):
             slots = []
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
@@ -289,10 +316,14 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)

         input_lengths = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
         )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
+        cache_lengths_tensor = torch.zeros(
+            batch_size, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
+        )
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])

         seqlen = Seqlen(
@@ -303,6 +334,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         lm_head_indices = input_lengths - 1

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = image_indices.repeat(batch_size)
+        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states, image_indices, seqlen, prompt_len, True
+        )
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -310,11 +347,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             kv_cache=self.kv_cache,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
-            lm_head_indices=lm_head_indices,
-            cross_attention_states=batch.cross_attention_states,
-            adapter_data=None,
             hpu_attention_meta=None,
-            image_indices=batch.image_indices[:],
+            lm_head_indices=lm_head_indices,
+            adapter_data=None,
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
         )

     def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
@@ -433,6 +471,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         else:
             padded_bs = input_lengths.shape[0]
         orig_bs = input_lengths.shape[0]
+        padded_input_len = input_ids.view(orig_bs, -1).shape[-1]
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
         if padded_bs != input_lengths.shape[0]:
             padded_input_lengths = F.pad(
                 input_lengths,
@@ -454,26 +494,41 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                 cache_lengths=padded_cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
             )
-            input_seq = input_ids.view(orig_bs, -1)
             input_ids = F.pad(
-                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+                input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0
             )
             position_ids = F.pad(
-                position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
-            )
-            slots = F.pad(
-                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+                position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1
             )
+            slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
             if lm_head_indices is not None:
                 lm_head_indices = F.pad(
                     lm_head_indices, (0, padded_bs - orig_bs), value=0
                 )
+            if cross_attention_states is not None:
+                cross_attention_states = F.pad(
+                    cross_attention_states,
+                    (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
+                    value=0,
+                )
+            if len(image_indices) != 0:
+                pad_indices = torch.arange(orig_bs, padded_bs, device=self.device)
+                image_indices = torch.cat((image_indices, pad_indices), dim=0)
         else:
             seqlen = Seqlen(
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
             )

+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states,
+            image_indices,
+            seqlen,
+            padded_input_len,
+            batch.prefilling,
+        )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
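
Note the padding layout for the 3-D cross-attention states: `F.pad` consumes the pad tuple from the last dimension backwards, so `(0, 0, 0, 0, 0, padded_bs - orig_bs)` leaves the hidden and sequence dimensions alone and appends zero "images" along dim 0. A sketch with dummy sizes:

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 6, 4)                  # (n_images, seqlen_k, hidden)
    padded = F.pad(x, (0, 0, 0, 0, 0, 1), value=0)
    print(padded.shape)                       # torch.Size([3, 6, 4])
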
@@ -483,10 +538,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
             lm_head_indices=lm_head_indices,
-            cross_attention_states=cross_attention_states,
             # TODO list
             adapter_data=None,
-            image_indices=batch.image_indices[:],
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
             **kwargs,
         )
         if batch.pixel_values is not None: