diff --git a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
index 216642e0..421a0a65 100644
--- a/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
+++ b/backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py
@@ -681,11 +681,10 @@ class MllamaTextCrossAttention(nn.Module):
         # bsz, q_len, _ = hidden_states.size()
         (
             cross_attention_states,
-            cu_seqlen_q,
-            cu_seqlen_k,
+            cross_attention_len,
             indices,
         ) = cross_attention_states
-        bs = cu_seqlen_q.size(0) - 1
+        bs = cross_attention_len.size(0)
         query_states = self.q_proj(hidden_states)
         query_states = query_states.view(bs, -1, self.num_heads, self.head_size)
         query_states = self.q_norm(query_states)
@@ -814,8 +813,6 @@ class FlashLlamaCrossLayer(torch.nn.Module):
         indices = cross_attention_states[-1]
         out_hidden_states = hidden_states[:]
-        if len(indices) > 0:
-            assert max(indices) < hidden_states.shape[0]
         hidden_states = hidden_states[indices]
 
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
@@ -914,59 +911,14 @@ class FlashMllamaForConditionalGeneration(nn.Module):
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
         lm_head_indices: Optional[torch.Tensor],
         adapter_data: Optional[torch.Tensor] = None,
-        # XXX: Putting these as optional so that the cuda warmup calls can go through.
         cross_attention_states: Optional[torch.Tensor] = None,
-        image_indices=None,
+        indices=None,
+        cross_attention_len: Optional[torch.Tensor] = None,
     ):
         if cross_attention_states is not None:
-            seqlen_q = len(image_indices)
-            n_images = cross_attention_states.shape[0]
-            seqlen_k = cross_attention_states.shape[1]
-            device = cross_attention_states.device
-            if cu_seqlen_prefill is not None:
-                offset = 0
-                cu_q = []
-                indices = []
-                for index in image_indices:
-                    cu_q.append(offset)
-                    length = seqlen.input_lengths[index].item()
-                    assert index < seqlen.cu_seqlen_q.shape[0]
-                    input_ids_offset = seqlen.cu_seqlen_q[index]
-                    indices.extend(range(input_ids_offset, input_ids_offset + length))
-                    offset += length
-                cu_q.append(offset)
-                cu_seqlen_q = torch.Tensor(cu_q).to(device=device, dtype=torch.int32)
-
-                assert max(indices) < input_ids.shape[0]
-
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-            else:
-                cu_seqlen_q = torch.arange(
-                    seqlen_q + 1, device=device, dtype=torch.int32
-                )
-                seqlen_k = cross_attention_states.shape[1]
-                n_images = cross_attention_states.shape[0]
-                cu_seqlen_k = (
-                    torch.arange(
-                        n_images + 1,
-                        device=device,
-                        dtype=torch.int32,
-                    )
-                    * seqlen_k
-                )
-                indices = image_indices[:]
-
             cross_attention_states = (
                 cross_attention_states,
-                cu_seqlen_q,
-                cu_seqlen_k,
+                cross_attention_len,
                 indices,
             )
 
diff --git a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
index 334f004e..23a40016 100644
--- a/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/flash_causal_lm.py
@@ -1538,19 +1538,21 @@ class FlashCausalLM(Model):
                 self.warmup_decode(batch_size, block_num, batch)
         synchronize(self.device)
 
-    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashCausalLMBatch):
+    def warmup_prefill(
+        self, prompt_len: int, batch_size: int, batch: FlashCausalLMBatch
+    ):
         input_ids = torch.zeros(
             prompt_len, dtype=batch.input_ids.dtype, device=self.device
-        ).repeat(bs)
+        ).repeat(batch_size)
         position_ids = torch.arange(
             prompt_len, dtype=batch.position_ids.dtype, device=self.device
-        ).repeat(bs)
-        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
+        ).repeat(batch_size)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
-        ).reshape(bs, -1)
+        ).reshape(batch_size, -1)
         slot_acc = []
-        for i in range(bs):
+        for i in range(batch_size):
             slots = []
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
@@ -1558,10 +1560,14 @@ class FlashCausalLM(Model):
         slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
 
         input_lengths = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
+        )
+        cache_lengths_tensor = torch.zeros(
+            batch_size, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
         )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
 
         seqlen = Seqlen(
diff --git a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
index acd5d9a5..940ee1b0 100644
--- a/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
+++ b/backends/gaudi/server/text_generation_server/models/mllama_causal_lm.py
@@ -205,6 +205,24 @@ class FlashMllamaCausalLMBatch(FlashVlmCausalLMBatch):
         return batch
 
 
+def generate_cross_attention_states(
+    cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
+):
+    if cross_attention_states is None:
+        return None, None
+    device = cross_attention_states.device
+    indices_list = []
+    if prefilling:
+        for i in image_indices:
+            indices_list.append(
+                torch.arange(pad_seq_len * i, pad_seq_len * (i + 1), device=device)
+            )
+        indices = torch.cat(indices_list, dim=0)
+    else:
+        indices = image_indices[:]
+    return indices, seqlen.input_lengths.index_select(0, image_indices)
+
+
 class FlashMllamaCausalLM(FlashVlmCausalLM):
     def warmup_decode(
         self, batch_size: int, block_num: int, batch: FlashMllamaCausalLMBatch
@@ -255,6 +273,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             bucketing_ctx=None,
         )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = image_indices.repeat(batch_size)
+        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states, image_indices, seqlen, 1, False
+        )
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -262,26 +286,29 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             kv_cache=self.kv_cache,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
+            hpu_attention_meta=hpu_attention_meta,
             lm_head_indices=None,
             adapter_data=None,
-            hpu_attention_meta=hpu_attention_meta,
-            cross_attention_states=batch.cross_attention_states,
-            image_indices=batch.image_indices[:],
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
         )
 
-    def warmup_prefill(self, prompt_len: int, bs: int, batch: FlashMllamaCausalLMBatch):
+    def warmup_prefill(
+        self, prompt_len: int, batch_size: int, batch: FlashMllamaCausalLMBatch
+    ):
         input_ids = torch.zeros(
             prompt_len, dtype=batch.input_ids.dtype, device=self.device
-        ).repeat(bs)
+        ).repeat(batch_size)
         position_ids = torch.arange(
             prompt_len, dtype=batch.position_ids.dtype, device=self.device
-        ).repeat(bs)
-        max_bt = (prompt_len // BLOCK_SIZE + 1) * bs
+        ).repeat(batch_size)
+        max_bt = (prompt_len // BLOCK_SIZE + 1) * batch_size
         block_tables = torch.arange(
             max_bt, dtype=torch.int32, device=self.device
-        ).reshape(bs, -1)
+        ).reshape(batch_size, -1)
         slot_acc = []
-        for i in range(bs):
+        for i in range(batch_size):
             slots = []
             for b in block_tables[i]:
                 slots.extend(range(b * BLOCK_SIZE, (b + 1) * BLOCK_SIZE))
@@ -289,10 +316,14 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         slots = torch.tensor(slot_acc, dtype=batch.slots.dtype, device=self.device)
 
         input_lengths = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * prompt_len
+            torch.ones(batch_size, dtype=torch.int32, device=self.device) * prompt_len
+        )
+        cache_lengths_tensor = torch.zeros(
+            batch_size, dtype=torch.int32, device=self.device
+        )
+        cu_seqlen_prefill = torch.zeros(
+            batch_size + 1, device=self.device, dtype=torch.int32
        )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        cu_seqlen_prefill = torch.zeros(bs + 1, device=self.device, dtype=torch.int32)
         torch.cumsum(input_lengths, -1, out=cu_seqlen_prefill[1:])
 
         seqlen = Seqlen(
@@ -303,6 +334,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         lm_head_indices = input_lengths - 1
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
+        image_indices = image_indices.repeat(batch_size)
+        cross_attention_states = batch.cross_attention_states.repeat(batch_size, 1, 1)
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states, image_indices, seqlen, prompt_len, True
+        )
         self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -310,11 +347,12 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             kv_cache=self.kv_cache,
             slots=slots,
             seqlen=trim_seqlen_metadata(seqlen),
-            lm_head_indices=lm_head_indices,
-            cross_attention_states=batch.cross_attention_states,
-            adapter_data=None,
             hpu_attention_meta=None,
-            image_indices=batch.image_indices[:],
+            lm_head_indices=lm_head_indices,
+            adapter_data=None,
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
         )
 
     def warmup_hpu_graph(self, batch: FlashMllamaCausalLMBatch):
@@ -433,6 +471,8 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
         else:
             padded_bs = input_lengths.shape[0]
         orig_bs = input_lengths.shape[0]
+        padded_input_len = input_ids.view(orig_bs, -1).shape[-1]
+        image_indices = torch.tensor(batch.image_indices, device=self.device)
         if padded_bs != input_lengths.shape[0]:
             padded_input_lengths = F.pad(
                 input_lengths,
@@ -454,26 +494,41 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
                 cache_lengths=padded_cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
             )
-            input_seq = input_ids.view(orig_bs, -1)
+
             input_ids = F.pad(
-                input_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+                input_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=0
             )
             position_ids = F.pad(
-                position_ids, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=1
-            )
-            slots = F.pad(
-                slots, (0, (padded_bs - orig_bs) * input_seq.shape[-1]), value=0
+                position_ids, (0, (padded_bs - orig_bs) * padded_input_len), value=1
             )
+            slots = F.pad(slots, (0, (padded_bs - orig_bs) * padded_input_len), value=0)
             if lm_head_indices is not None:
                 lm_head_indices = F.pad(
                     lm_head_indices, (0, padded_bs - orig_bs), value=0
                 )
+            if cross_attention_states is not None:
+                cross_attention_states = F.pad(
+                    cross_attention_states,
+                    (0, 0, 0, 0, 0, (padded_bs - orig_bs)),
+                    value=0,
+                )
+            if len(image_indices) != 0:
+                pad_indices = torch.arange(orig_bs, padded_bs, device=self.device)
+                image_indices = torch.cat((image_indices, pad_indices), dim=0)
         else:
             seqlen = Seqlen(
                 input_lengths=input_lengths,
                 cache_lengths=cache_lengths_tensor,
                 cu_seqlen_q=cu_seqlen_prefill,
             )
+
+        indices, cross_attention_len = generate_cross_attention_states(
+            cross_attention_states,
+            image_indices,
+            seqlen,
+            padded_input_len,
+            batch.prefilling,
+        )
         logits, speculative_logits = self.model.forward(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -483,10 +538,11 @@ class FlashMllamaCausalLM(FlashVlmCausalLM):
             seqlen=trim_seqlen_metadata(seqlen),
             hpu_attention_meta=batch.hpu_attn_meta,
             lm_head_indices=lm_head_indices,
-            cross_attention_states=cross_attention_states,
             # TODO list
             adapter_data=None,
-            image_indices=batch.image_indices[:],
+            cross_attention_states=cross_attention_states,
+            indices=indices,
+            cross_attention_len=cross_attention_len,
             **kwargs,
         )
         if batch.pixel_values is not None:
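For reviewers, a minimal standalone sketch of the bookkeeping this diff introduces: the per-image `indices` / `cross_attention_len` pair that replaces the old `cu_seqlen_q` / `cu_seqlen_k` tensors. It is not part of the diff; `Seqlen` below is a stand-in namedtuple (assumption: only its `input_lengths` field matters here), and the helper mirrors `generate_cross_attention_states` from `mllama_causal_lm.py` above.

```python
# Sketch only: how (indices, cross_attention_len) is built for prefill vs. decode.
from collections import namedtuple

import torch

# Stand-in for the real TGI Seqlen metadata object (assumption).
Seqlen = namedtuple("Seqlen", ["input_lengths"])


def generate_cross_attention_states(
    cross_attention_states, image_indices, seqlen, pad_seq_len, prefilling
):
    # Mirrors the helper added to mllama_causal_lm.py in the diff above.
    if cross_attention_states is None:
        return None, None
    device = cross_attention_states.device
    if prefilling:
        # Prefill: every padded token position of an image-bearing sequence
        # attends to that sequence's image states.
        indices = torch.cat(
            [
                torch.arange(
                    pad_seq_len * int(i), pad_seq_len * (int(i) + 1), device=device
                )
                for i in image_indices
            ],
            dim=0,
        )
    else:
        # Decode: each image-bearing sequence contributes exactly one new token.
        indices = image_indices[:]
    # Per-image effective query length, replacing the old cu_seqlen_q/cu_seqlen_k.
    return indices, seqlen.input_lengths.index_select(0, image_indices)


if __name__ == "__main__":
    # Two sequences, both with an image, padded prompt length 4 (toy shapes).
    states = torch.zeros(2, 8, 16)  # (n_images, vision_seq_len, hidden)
    image_indices = torch.tensor([0, 1])
    seqlen = Seqlen(input_lengths=torch.tensor([3, 4], dtype=torch.int32))
    indices, cross_attention_len = generate_cross_attention_states(
        states, image_indices, seqlen, pad_seq_len=4, prefilling=True
    )
    print(indices)               # tensor([0, 1, 2, 3, 4, 5, 6, 7])
    print(cross_attention_len)   # tensor([3, 4], dtype=torch.int32)
```

With this shape, `MllamaTextCrossAttention.forward` can recover the batch size directly as `cross_attention_len.size(0)`, which is why the old `cu_seqlen_q.size(0) - 1` computation and the prefill-time index rebuilding in `FlashMllamaForConditionalGeneration.forward` could be dropped.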